{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "37236084",
   "metadata": {
    "origin_pos": 0
   },
   "source": [
    "# Image Classification (CIFAR-10) on Kaggle\n",
    ":label:`sec_kaggle_cifar10`\n",
    "\n",
    "So far, we have been using high-level APIs of deep learning frameworks to directly obtain image datasets in tensor format.\n",
    "However, custom image datasets\n",
    "often come in the form of image files.\n",
    "In this section, we will start from\n",
    "raw image files,\n",
    "and organize, read, then transform them\n",
    "into tensor format step by step.\n",
    "\n",
    "We experimented with the CIFAR-10 dataset in :numref:`sec_image_augmentation`,\n",
    "which is an important dataset in computer vision.\n",
    "In this section,\n",
    "we will apply the knowledge we learned\n",
    "in previous sections\n",
    "to practice the Kaggle competition of\n",
    "CIFAR-10 image classification.\n",
    "(**The web address of the competition is https://www.kaggle.com/c/cifar-10**)\n",
    "\n",
    ":numref:`fig_kaggle_cifar10` shows the information on the competition's webpage.\n",
    "In order to submit the results,\n",
    "you need to register a Kaggle account.\n",
    "\n",
    "![CIFAR-10 image classification competition webpage information. The competition dataset can be obtained by clicking the \"Data\" tab.](../img/kaggle-cifar10.png)\n",
    ":width:`600px`\n",
    ":label:`fig_kaggle_cifar10`\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "e5a3fb64",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:29:26.260293Z",
     "iopub.status.busy": "2023-08-18T19:29:26.259545Z",
     "iopub.status.idle": "2023-08-18T19:29:29.089576Z",
     "shell.execute_reply": "2023-08-18T19:29:29.088675Z"
    },
    "origin_pos": 2,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "import collections\n",
    "import math\n",
    "import os\n",
    "import shutil\n",
    "import pandas as pd\n",
    "import torch\n",
    "import torchvision\n",
    "from torch import nn\n",
    "from d2l import torch as d2l"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "403ecf47",
   "metadata": {
    "origin_pos": 3
   },
   "source": [
    "## Obtaining and Organizing the Dataset\n",
    "\n",
    "The competition dataset is divided into\n",
    "a training set and a test set,\n",
    "which contain 50000 and 300000 images, respectively.\n",
    "In the test set,\n",
    "10000 images will be used for evaluation,\n",
    "while the remaining 290000 images will not\n",
    "be evaluated:\n",
    "they are included just\n",
    "to make it hard\n",
    "to cheat with\n",
    "*manually* labeled results of the test set.\n",
    "The images in this dataset\n",
    "are all png color (RGB channels) image files,\n",
    "whose height and width are both 32 pixels.\n",
    "The images cover a total of 10 categories, namely airplanes, cars, birds, cats, deer, dogs, frogs, horses, boats, and trucks.\n",
    "The upper-left corner of :numref:`fig_kaggle_cifar10` shows some images of airplanes, cars, and birds in the dataset.\n",
    "\n",
    "\n",
    "### Downloading the Dataset\n",
    "\n",
    "After logging in to Kaggle, we can click the \"Data\" tab on the CIFAR-10 image classification competition webpage shown in :numref:`fig_kaggle_cifar10` and download the dataset by clicking the \"Download All\" button.\n",
    "After unzipping the downloaded file in `../data`, and unzipping `train.7z` and `test.7z` inside it, you will find the entire dataset in the following paths:\n",
    "\n",
    "* `../data/cifar-10/train/[1-50000].png`\n",
    "* `../data/cifar-10/test/[1-300000].png`\n",
    "* `../data/cifar-10/trainLabels.csv`\n",
    "* `../data/cifar-10/sampleSubmission.csv`\n",
    "\n",
    "where the `train` and `test` directories contain the training and testing images, respectively, `trainLabels.csv` provides labels for the training images, and `sample_submission.csv` is a sample submission file.\n",
    "\n",
    "To make it easier to get started, [**we provide a small-scale sample of the dataset that\n",
    "contains the first 1000 training images and 5 random testing images.**]\n",
    "To use the full dataset of the Kaggle competition, you need to set the following `demo` variable to `False`.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "0d41dcd1",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:29:29.095074Z",
     "iopub.status.busy": "2023-08-18T19:29:29.094404Z",
     "iopub.status.idle": "2023-08-18T19:29:29.393994Z",
     "shell.execute_reply": "2023-08-18T19:29:29.393137Z"
    },
    "origin_pos": 4,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading ../data/kaggle_cifar10_tiny.zip from http://d2l-data.s3-accelerate.amazonaws.com/kaggle_cifar10_tiny.zip...\n"
     ]
    }
   ],
   "source": [
    "#@save\n",
    "d2l.DATA_HUB['cifar10_tiny'] = (d2l.DATA_URL + 'kaggle_cifar10_tiny.zip',\n",
    "                                '2068874e4b9a9f0fb07ebe0ad2b29754449ccacd')\n",
    "\n",
    "# If you use the full dataset downloaded for the Kaggle competition, set\n",
    "# `demo` to False\n",
    "demo = True\n",
    "\n",
    "if demo:\n",
    "    data_dir = d2l.download_extract('cifar10_tiny')\n",
    "else:\n",
    "    data_dir = '../data/cifar-10/'"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0f716217",
   "metadata": {
    "origin_pos": 5
   },
   "source": [
    "### [**Organizing the Dataset**]\n",
    "\n",
    "We need to organize datasets to facilitate model training and testing.\n",
    "Let's first read the labels from the csv file.\n",
    "The following function returns a dictionary that maps\n",
    "the non-extension part of the filename to its label.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "04bf8387",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:29:29.399003Z",
     "iopub.status.busy": "2023-08-18T19:29:29.398718Z",
     "iopub.status.idle": "2023-08-18T19:29:29.406335Z",
     "shell.execute_reply": "2023-08-18T19:29:29.405552Z"
    },
    "origin_pos": 6,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "# training examples: 1000\n",
      "# classes: 10\n"
     ]
    }
   ],
   "source": [
    "#@save\n",
    "def read_csv_labels(fname):\n",
    "    \"\"\"Read `fname` to return a filename to label dictionary.\"\"\"\n",
    "    with open(fname, 'r') as f:\n",
    "        # Skip the file header line (column name)\n",
    "        lines = f.readlines()[1:]\n",
    "    tokens = [l.rstrip().split(',') for l in lines]\n",
    "    return dict(((name, label) for name, label in tokens))\n",
    "\n",
    "labels = read_csv_labels(os.path.join(data_dir, 'trainLabels.csv'))\n",
    "print('# training examples:', len(labels))\n",
    "print('# classes:', len(set(labels.values())))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "abdf4d16",
   "metadata": {
    "origin_pos": 7
   },
   "source": [
    "Next, we define the `reorg_train_valid` function to [**split the validation set out of the original training set.**]\n",
    "The argument `valid_ratio` in this function is the ratio of the number of examples in the validation set to the number of examples in the original training set.\n",
    "More concretely,\n",
    "let $n$ be the number of images of the class with the least examples, and $r$ be the ratio.\n",
    "The validation set will split out\n",
    "$\\max(\\lfloor nr\\rfloor,1)$ images for each class.\n",
    "Let's use `valid_ratio=0.1` as an example. Since the original training set has 50000 images,\n",
    "there will be 45000 images used for training in the path `train_valid_test/train`,\n",
    "while the other 5000 images will be split out\n",
    "as validation set in the path `train_valid_test/valid`. After organizing the dataset, images of the same class will be placed under the same folder.\n"
   ]
  },
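  {
   "cell_type": "markdown",
   "id": "a7c3e1f0",
   "metadata": {},
   "source": [
    "As a quick check of the split sizes quoted above, here is the arithmetic for the full dataset. It assumes that every class has exactly 5000 training images, as in the standard CIFAR-10 training set, so that $n = 5000$; with `valid_ratio=0.1` we have $r = 0.1$:\n",
    "\n",
    "$$\\max(\\lfloor nr\\rfloor, 1) = \\max(\\lfloor 5000 \\times 0.1\\rfloor, 1) = 500, \\qquad 10 \\times 500 = 5000, \\qquad 50000 - 5000 = 45000.$$\n",
    "\n",
    "That is, 500 images per class go to the validation set (5000 in total) and the remaining 45000 are used for training. The $\\max(\\cdot, 1)$ only matters for very small datasets, where $\\lfloor nr\\rfloor$ could otherwise be zero.\n"
   ]
  },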
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "0ae3357e",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:29:29.411145Z",
     "iopub.status.busy": "2023-08-18T19:29:29.410869Z",
     "iopub.status.idle": "2023-08-18T19:29:29.418258Z",
     "shell.execute_reply": "2023-08-18T19:29:29.417439Z"
    },
    "origin_pos": 8,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "#@save\n",
    "def copyfile(filename, target_dir):\n",
    "    \"\"\"Copy a file into a target directory.\"\"\"\n",
    "    os.makedirs(target_dir, exist_ok=True)\n",
    "    shutil.copy(filename, target_dir)\n",
    "\n",
    "#@save\n",
    "def reorg_train_valid(data_dir, labels, valid_ratio):\n",
    "    \"\"\"Split the validation set out of the original training set.\"\"\"\n",
    "    # The number of examples of the class that has the fewest examples in the\n",
    "    # training dataset\n",
    "    n = collections.Counter(labels.values()).most_common()[-1][1]\n",
    "    # The number of examples per class for the validation set\n",
    "    n_valid_per_label = max(1, math.floor(n * valid_ratio))\n",
    "    label_count = {}\n",
    "    for train_file in os.listdir(os.path.join(data_dir, 'train')):\n",
    "        label = labels[train_file.split('.')[0]]\n",
    "        fname = os.path.join(data_dir, 'train', train_file)\n",
    "        copyfile(fname, os.path.join(data_dir, 'train_valid_test',\n",
    "                                     'train_valid', label))\n",
    "        if label not in label_count or label_count[label] < n_valid_per_label:\n",
    "            copyfile(fname, os.path.join(data_dir, 'train_valid_test',\n",
    "                                         'valid', label))\n",
    "            label_count[label] = label_count.get(label, 0) + 1\n",
    "        else:\n",
    "            copyfile(fname, os.path.join(data_dir, 'train_valid_test',\n",
    "                                         'train', label))\n",
    "    return n_valid_per_label"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6d2cfa19",
   "metadata": {
    "origin_pos": 9
   },
   "source": [
    "The `reorg_test` function below [**organizes the testing set for data loading during prediction.**]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "890972a8",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:29:29.422565Z",
     "iopub.status.busy": "2023-08-18T19:29:29.422289Z",
     "iopub.status.idle": "2023-08-18T19:29:29.426856Z",
     "shell.execute_reply": "2023-08-18T19:29:29.426083Z"
    },
    "origin_pos": 10,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "#@save\n",
    "def reorg_test(data_dir):\n",
    "    \"\"\"Organize the testing set for data loading during prediction.\"\"\"\n",
    "    for test_file in os.listdir(os.path.join(data_dir, 'test')):\n",
    "        copyfile(os.path.join(data_dir, 'test', test_file),\n",
    "                 os.path.join(data_dir, 'train_valid_test', 'test',\n",
    "                              'unknown'))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0e790936",
   "metadata": {
    "origin_pos": 11
   },
   "source": [
    "Finally, we use a function to [**invoke**]\n",
    "the `read_csv_labels`, `reorg_train_valid`, and `reorg_test` (**functions defined above.**)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "00f50b41",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:29:29.431514Z",
     "iopub.status.busy": "2023-08-18T19:29:29.430810Z",
     "iopub.status.idle": "2023-08-18T19:29:29.434961Z",
     "shell.execute_reply": "2023-08-18T19:29:29.434181Z"
    },
    "origin_pos": 12,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "def reorg_cifar10_data(data_dir, valid_ratio):\n",
    "    labels = read_csv_labels(os.path.join(data_dir, 'trainLabels.csv'))\n",
    "    reorg_train_valid(data_dir, labels, valid_ratio)\n",
    "    reorg_test(data_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d0454395",
   "metadata": {
    "origin_pos": 13
   },
   "source": [
    "Here we only set the batch size to 32 for the small-scale sample of the dataset.\n",
    "When training and testing\n",
    "the complete dataset of the Kaggle competition,\n",
    "`batch_size` should be set to a larger integer, such as 128.\n",
    "We split out 10% of the training examples as the validation set for tuning hyperparameters.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "1daf58c4",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:29:29.439643Z",
     "iopub.status.busy": "2023-08-18T19:29:29.438882Z",
     "iopub.status.idle": "2023-08-18T19:29:29.700309Z",
     "shell.execute_reply": "2023-08-18T19:29:29.699321Z"
    },
    "origin_pos": 14,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "batch_size = 32 if demo else 128\n",
    "valid_ratio = 0.1\n",
    "reorg_cifar10_data(data_dir, valid_ratio)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "240c5c50",
   "metadata": {
    "origin_pos": 15
   },
   "source": [
    "## [**Image Augmentation**]\n",
    "\n",
    "We use image augmentation to address overfitting.\n",
    "For example, images can be flipped horizontally at random during training.\n",
    "We can also perform standardization for the three RGB channels of color images. Below lists some of these operations that you can tweak.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "70e97f85",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:29:29.705470Z",
     "iopub.status.busy": "2023-08-18T19:29:29.704953Z",
     "iopub.status.idle": "2023-08-18T19:29:29.710326Z",
     "shell.execute_reply": "2023-08-18T19:29:29.709544Z"
    },
    "origin_pos": 17,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "transform_train = torchvision.transforms.Compose([\n",
    "    # Scale the image up to a square of 40 pixels in both height and width\n",
    "    torchvision.transforms.Resize(40),\n",
    "    # Randomly crop a square image of 40 pixels in both height and width to\n",
    "    # produce a small square of 0.64 to 1 times the area of the original\n",
    "    # image, and then scale it to a square of 32 pixels in both height and\n",
    "    # width\n",
    "    torchvision.transforms.RandomResizedCrop(32, scale=(0.64, 1.0),\n",
    "                                                   ratio=(1.0, 1.0)),\n",
    "    torchvision.transforms.RandomHorizontalFlip(),\n",
    "    torchvision.transforms.ToTensor(),\n",
    "    # Standardize each channel of the image\n",
    "    torchvision.transforms.Normalize([0.4914, 0.4822, 0.4465],\n",
    "                                     [0.2023, 0.1994, 0.2010])])"
   ]
  },
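  {
   "cell_type": "markdown",
   "id": "b4d92c6e",
   "metadata": {},
   "source": [
    "To see what the training transform above does to image sizes, it may help to work through the crop arithmetic. `Resize(40)` turns each $32 \\times 32$ image into a $40 \\times 40$ image with an area of $1600$ pixels. With `scale=(0.64, 1.0)` and `ratio=(1.0, 1.0)`, the random crop is a square whose side $s$ (notation introduced here for illustration) satisfies, up to rounding to whole pixels,\n",
    "\n",
    "$$0.64 \\times 40^2 = 1024 \\le s^2 \\le 1600 = 40^2 \\quad\\Longrightarrow\\quad 32 \\le s \\le 40,$$\n",
    "\n",
    "so the crop is never smaller than the original image size before it is resized to $32 \\times 32$. The final `Normalize` step standardizes each channel $c$ of the tensor $x$, whose values lie in $[0, 1]$ after `ToTensor`, as\n",
    "\n",
    "$$\\hat{x}_c = \\frac{x_c - \\mu_c}{\\sigma_c},$$\n",
    "\n",
    "where $\\mu = (0.4914, 0.4822, 0.4465)$ and $\\sigma = (0.2023, 0.1994, 0.2010)$ are the per-channel values passed in the code.\n"
   ]
  },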
  {
   "cell_type": "markdown",
   "id": "d7593105",
   "metadata": {
    "origin_pos": 18
   },
   "source": [
    "During testing,\n",
    "we only perform standardization on images\n",
    "so as to\n",
    "remove randomness in the evaluation results.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "be0d5428",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:29:29.714890Z",
     "iopub.status.busy": "2023-08-18T19:29:29.714292Z",
     "iopub.status.idle": "2023-08-18T19:29:29.718602Z",
     "shell.execute_reply": "2023-08-18T19:29:29.717807Z"
    },
    "origin_pos": 20,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "transform_test = torchvision.transforms.Compose([\n",
    "    torchvision.transforms.ToTensor(),\n",
    "    torchvision.transforms.Normalize([0.4914, 0.4822, 0.4465],\n",
    "                                     [0.2023, 0.1994, 0.2010])])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "27f918e0",
   "metadata": {
    "origin_pos": 21
   },
   "source": [
    "## Reading the Dataset\n",
    "\n",
    "Next, we [**read the organized dataset consisting of raw image files**]. Each example includes an image and a label.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "056ac33a",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:29:29.722917Z",
     "iopub.status.busy": "2023-08-18T19:29:29.722506Z",
     "iopub.status.idle": "2023-08-18T19:29:29.733889Z",
     "shell.execute_reply": "2023-08-18T19:29:29.733119Z"
    },
    "origin_pos": 23,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "train_ds, train_valid_ds = [torchvision.datasets.ImageFolder(\n",
    "    os.path.join(data_dir, 'train_valid_test', folder),\n",
    "    transform=transform_train) for folder in ['train', 'train_valid']]\n",
    "\n",
    "valid_ds, test_ds = [torchvision.datasets.ImageFolder(\n",
    "    os.path.join(data_dir, 'train_valid_test', folder),\n",
    "    transform=transform_test) for folder in ['valid', 'test']]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "16747ae7",
   "metadata": {
    "origin_pos": 24
   },
   "source": [
    "During training,\n",
    "we need to [**specify all the image augmentation operations defined above**].\n",
    "When the validation set\n",
    "is used for model evaluation during hyperparameter tuning,\n",
    "no randomness from image augmentation should be introduced.\n",
    "Before final prediction,\n",
    "we train the model on the combined training set and validation set to make full use of all the labeled data.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "06fa7207",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:29:29.738557Z",
     "iopub.status.busy": "2023-08-18T19:29:29.737952Z",
     "iopub.status.idle": "2023-08-18T19:29:29.743073Z",
     "shell.execute_reply": "2023-08-18T19:29:29.742323Z"
    },
    "origin_pos": 26,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "train_iter, train_valid_iter = [torch.utils.data.DataLoader(\n",
    "    dataset, batch_size, shuffle=True, drop_last=True)\n",
    "    for dataset in (train_ds, train_valid_ds)]\n",
    "\n",
    "valid_iter = torch.utils.data.DataLoader(valid_ds, batch_size, shuffle=False,\n",
    "                                         drop_last=True)\n",
    "\n",
    "test_iter = torch.utils.data.DataLoader(test_ds, batch_size, shuffle=False,\n",
    "                                        drop_last=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9e84ffa3",
   "metadata": {
    "origin_pos": 27
   },
   "source": [
    "## Defining the [**Model**]\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aceea7ad",
   "metadata": {
    "origin_pos": 33,
    "tab": [
     "pytorch"
    ]
   },
   "source": [
    "We define the ResNet-18 model described in\n",
    ":numref:`sec_resnet`.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "d527425d",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:29:29.747678Z",
     "iopub.status.busy": "2023-08-18T19:29:29.747059Z",
     "iopub.status.idle": "2023-08-18T19:29:29.751129Z",
     "shell.execute_reply": "2023-08-18T19:29:29.750380Z"
    },
    "origin_pos": 35,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "def get_net():\n",
    "    num_classes = 10\n",
    "    net = d2l.resnet18(num_classes, 3)\n",
    "    return net\n",
    "\n",
    "loss = nn.CrossEntropyLoss(reduction=\"none\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b0d9de60",
   "metadata": {
    "origin_pos": 36
   },
   "source": [
    "## Defining the [**Training Function**]\n",
    "\n",
    "We will select models and tune hyperparameters according to the model's performance on the validation set.\n",
    "In the following, we define the model training function `train`.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "bde40789",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:29:29.755665Z",
     "iopub.status.busy": "2023-08-18T19:29:29.755131Z",
     "iopub.status.idle": "2023-08-18T19:29:29.764392Z",
     "shell.execute_reply": "2023-08-18T19:29:29.763621Z"
    },
    "origin_pos": 38,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "def train(net, train_iter, valid_iter, num_epochs, lr, wd, devices, lr_period,\n",
    "          lr_decay):\n",
    "    trainer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9,\n",
    "                              weight_decay=wd)\n",
    "    scheduler = torch.optim.lr_scheduler.StepLR(trainer, lr_period, lr_decay)\n",
    "    num_batches, timer = len(train_iter), d2l.Timer()\n",
    "    legend = ['train loss', 'train acc']\n",
    "    if valid_iter is not None:\n",
    "        legend.append('valid acc')\n",
    "    animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],\n",
    "                            legend=legend)\n",
    "    net = nn.DataParallel(net, device_ids=devices).to(devices[0])\n",
    "    for epoch in range(num_epochs):\n",
    "        net.train()\n",
    "        metric = d2l.Accumulator(3)\n",
    "        for i, (features, labels) in enumerate(train_iter):\n",
    "            timer.start()\n",
    "            l, acc = d2l.train_batch_ch13(net, features, labels,\n",
    "                                          loss, trainer, devices)\n",
    "            metric.add(l, acc, labels.shape[0])\n",
    "            timer.stop()\n",
    "            if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:\n",
    "                animator.add(epoch + (i + 1) / num_batches,\n",
    "                             (metric[0] / metric[2], metric[1] / metric[2],\n",
    "                              None))\n",
    "        if valid_iter is not None:\n",
    "            valid_acc = d2l.evaluate_accuracy_gpu(net, valid_iter)\n",
    "            animator.add(epoch + 1, (None, None, valid_acc))\n",
    "        scheduler.step()\n",
    "    measures = (f'train loss {metric[0] / metric[2]:.3f}, '\n",
    "                f'train acc {metric[1] / metric[2]:.3f}')\n",
    "    if valid_iter is not None:\n",
    "        measures += f', valid acc {valid_acc:.3f}'\n",
    "    print(measures + f'\\n{metric[2] * num_epochs / timer.sum():.1f}'\n",
    "          f' examples/sec on {str(devices)}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f285eced",
   "metadata": {
    "origin_pos": 39
   },
   "source": [
    "## [**Training and Validating the Model**]\n",
    "\n",
    "Now, we can train and validate the model.\n",
    "All the following hyperparameters can be tuned.\n",
    "For example, we can increase the number of epochs.\n",
    "When `lr_period` and `lr_decay` are set to 4 and 0.9, respectively, the learning rate of the optimization algorithm will be multiplied by 0.9 after every 4 epochs. Just for ease of demonstration,\n",
    "we only train 20 epochs here.\n"
   ]
  },
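  {
   "cell_type": "markdown",
   "id": "c58e07ab",
   "metadata": {},
   "source": [
    "The step schedule described above can be written in closed form. Writing $\\eta_0$ for the initial learning rate, $p$ for `lr_period`, $\\gamma$ for `lr_decay`, and $t = 0, 1, \\ldots$ for the zero-based epoch index (notation introduced here for illustration), the learning rate used in epoch $t$ is\n",
    "\n",
    "$$\\eta_t = \\eta_0 \\, \\gamma^{\\lfloor t / p \\rfloor}.$$\n",
    "\n",
    "For example, with $p = 4$ and $\\gamma = 0.9$, epochs 0 to 3 use $\\eta_0$, epochs 4 to 7 use $0.9\\,\\eta_0$, and the last four of 20 epochs (16 to 19) use $0.9^4\\,\\eta_0 = 0.6561\\,\\eta_0$.\n"
   ]
  },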
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "cd4a55c7",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:29:29.768734Z",
     "iopub.status.busy": "2023-08-18T19:29:29.768227Z",
     "iopub.status.idle": "2023-08-18T19:30:37.496878Z",
     "shell.execute_reply": "2023-08-18T19:30:37.495860Z"
    },
    "origin_pos": 41,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss 0.654, train acc 0.789, valid acc 0.438\n",
      "958.1 examples/sec on [device(type='cuda', index=0), device(type='cuda', index=1)]\n"
     ]
    },
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
       "  \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
       "<svg xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"238.965625pt\" height=\"183.35625pt\" viewBox=\"0 0 238.965625 183.35625\" xmlns=\"http://www.w3.org/2000/svg\" version=\"1.1\">\n",
       " <metadata>\n",
       "  <rdf:RDF xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\n",
       "   <cc:Work>\n",
       "    <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\n",
       "    <dc:date>2023-08-18T19:30:37.438438</dc:date>\n",
       "    <dc:format>image/svg+xml</dc:format>\n",
       "    <dc:creator>\n",
       "     <cc:Agent>\n",
       "      <dc:title>Matplotlib v3.7.2, https://matplotlib.org/</dc:title>\n",
       "     </cc:Agent>\n",
       "    </dc:creator>\n",
       "   </cc:Work>\n",
       "  </rdf:RDF>\n",
       " </metadata>\n",
       " <defs>\n",
       "  <style type=\"text/css\">*{stroke-linejoin: round; stroke-linecap: butt}</style>\n",
       " </defs>\n",
       " <g id=\"figure_1\">\n",
       "  <g id=\"patch_1\">\n",
       "   <path d=\"M 0 183.35625 \n",
       "L 238.965625 183.35625 \n",
       "L 238.965625 0 \n",
       "L 0 0 \n",
       "z\n",
       "\" style=\"fill: #ffffff\"/>\n",
       "  </g>\n",
       "  <g id=\"axes_1\">\n",
       "   <g id=\"patch_2\">\n",
       "    <path d=\"M 30.103125 145.8 \n",
       "L 225.403125 145.8 \n",
       "L 225.403125 7.2 \n",
       "L 30.103125 7.2 \n",
       "z\n",
       "\" style=\"fill: #ffffff\"/>\n",
       "   </g>\n",
       "   <g id=\"matplotlib.axis_1\">\n",
       "    <g id=\"xtick_1\">\n",
       "     <g id=\"line2d_1\">\n",
       "      <path d=\"M 71.218914 145.8 \n",
       "L 71.218914 7.2 \n",
       "\" clip-path=\"url(#pb5435cb16d)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
       "     </g>\n",
       "     <g id=\"line2d_2\">\n",
       "      <defs>\n",
       "       <path id=\"m99b527c8b7\" d=\"M 0 0 \n",
       "L 0 3.5 \n",
       "\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </defs>\n",
       "      <g>\n",
       "       <use xlink:href=\"#m99b527c8b7\" x=\"71.218914\" y=\"145.8\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_1\">\n",
       "      <!-- 5 -->\n",
       "      <g transform=\"translate(68.037664 160.398438) scale(0.1 -0.1)\">\n",
       "       <defs>\n",
       "        <path id=\"DejaVuSans-35\" d=\"M 691 4666 \n",
       "L 3169 4666 \n",
       "L 3169 4134 \n",
       "L 1269 4134 \n",
       "L 1269 2991 \n",
       "Q 1406 3038 1543 3061 \n",
       "Q 1681 3084 1819 3084 \n",
       "Q 2600 3084 3056 2656 \n",
       "Q 3513 2228 3513 1497 \n",
       "Q 3513 744 3044 326 \n",
       "Q 2575 -91 1722 -91 \n",
       "Q 1428 -91 1123 -41 \n",
       "Q 819 9 494 109 \n",
       "L 494 744 \n",
       "Q 775 591 1075 516 \n",
       "Q 1375 441 1709 441 \n",
       "Q 2250 441 2565 725 \n",
       "Q 2881 1009 2881 1497 \n",
       "Q 2881 1984 2565 2268 \n",
       "Q 2250 2553 1709 2553 \n",
       "Q 1456 2553 1204 2497 \n",
       "Q 953 2441 691 2322 \n",
       "L 691 4666 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       </defs>\n",
       "       <use xlink:href=\"#DejaVuSans-35\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"xtick_2\">\n",
       "     <g id=\"line2d_3\">\n",
       "      <path d=\"M 122.613651 145.8 \n",
       "L 122.613651 7.2 \n",
       "\" clip-path=\"url(#pb5435cb16d)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
       "     </g>\n",
       "     <g id=\"line2d_4\">\n",
       "      <g>\n",
       "       <use xlink:href=\"#m99b527c8b7\" x=\"122.613651\" y=\"145.8\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_2\">\n",
       "      <!-- 10 -->\n",
       "      <g transform=\"translate(116.251151 160.398438) scale(0.1 -0.1)\">\n",
       "       <defs>\n",
       "        <path id=\"DejaVuSans-31\" d=\"M 794 531 \n",
       "L 1825 531 \n",
       "L 1825 4091 \n",
       "L 703 3866 \n",
       "L 703 4441 \n",
       "L 1819 4666 \n",
       "L 2450 4666 \n",
       "L 2450 531 \n",
       "L 3481 531 \n",
       "L 3481 0 \n",
       "L 794 0 \n",
       "L 794 531 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "        <path id=\"DejaVuSans-30\" d=\"M 2034 4250 \n",
       "Q 1547 4250 1301 3770 \n",
       "Q 1056 3291 1056 2328 \n",
       "Q 1056 1369 1301 889 \n",
       "Q 1547 409 2034 409 \n",
       "Q 2525 409 2770 889 \n",
       "Q 3016 1369 3016 2328 \n",
       "Q 3016 3291 2770 3770 \n",
       "Q 2525 4250 2034 4250 \n",
       "z\n",
       "M 2034 4750 \n",
       "Q 2819 4750 3233 4129 \n",
       "Q 3647 3509 3647 2328 \n",
       "Q 3647 1150 3233 529 \n",
       "Q 2819 -91 2034 -91 \n",
       "Q 1250 -91 836 529 \n",
       "Q 422 1150 422 2328 \n",
       "Q 422 3509 836 4129 \n",
       "Q 1250 4750 2034 4750 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       </defs>\n",
       "       <use xlink:href=\"#DejaVuSans-31\"/>\n",
       "       <use xlink:href=\"#DejaVuSans-30\" x=\"63.623047\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"xtick_3\">\n",
       "     <g id=\"line2d_5\">\n",
       "      <path d=\"M 174.008388 145.8 \n",
       "L 174.008388 7.2 \n",
       "\" clip-path=\"url(#pb5435cb16d)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
       "     </g>\n",
       "     <g id=\"line2d_6\">\n",
       "      <g>\n",
       "       <use xlink:href=\"#m99b527c8b7\" x=\"174.008388\" y=\"145.8\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_3\">\n",
       "      <!-- 15 -->\n",
       "      <g transform=\"translate(167.645888 160.398438) scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-31\"/>\n",
       "       <use xlink:href=\"#DejaVuSans-35\" x=\"63.623047\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"xtick_4\">\n",
       "     <g id=\"line2d_7\">\n",
       "      <path d=\"M 225.403125 145.8 \n",
       "L 225.403125 7.2 \n",
       "\" clip-path=\"url(#pb5435cb16d)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
       "     </g>\n",
       "     <g id=\"line2d_8\">\n",
       "      <g>\n",
       "       <use xlink:href=\"#m99b527c8b7\" x=\"225.403125\" y=\"145.8\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_4\">\n",
       "      <!-- 20 -->\n",
       "      <g transform=\"translate(219.040625 160.398438) scale(0.1 -0.1)\">\n",
       "       <defs>\n",
       "        <path id=\"DejaVuSans-32\" d=\"M 1228 531 \n",
       "L 3431 531 \n",
       "L 3431 0 \n",
       "L 469 0 \n",
       "L 469 531 \n",
       "Q 828 903 1448 1529 \n",
       "Q 2069 2156 2228 2338 \n",
       "Q 2531 2678 2651 2914 \n",
       "Q 2772 3150 2772 3378 \n",
       "Q 2772 3750 2511 3984 \n",
       "Q 2250 4219 1831 4219 \n",
       "Q 1534 4219 1204 4116 \n",
       "Q 875 4013 500 3803 \n",
       "L 500 4441 \n",
       "Q 881 4594 1212 4672 \n",
       "Q 1544 4750 1819 4750 \n",
       "Q 2544 4750 2975 4387 \n",
       "Q 3406 4025 3406 3419 \n",
       "Q 3406 3131 3298 2873 \n",
       "Q 3191 2616 2906 2266 \n",
       "Q 2828 2175 2409 1742 \n",
       "Q 1991 1309 1228 531 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       </defs>\n",
       "       <use xlink:href=\"#DejaVuSans-32\"/>\n",
       "       <use xlink:href=\"#DejaVuSans-30\" x=\"63.623047\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"text_5\">\n",
       "     <!-- epoch -->\n",
       "     <g transform=\"translate(112.525 174.076563) scale(0.1 -0.1)\">\n",
       "      <defs>\n",
       "       <path id=\"DejaVuSans-65\" d=\"M 3597 1894 \n",
       "L 3597 1613 \n",
       "L 953 1613 \n",
       "Q 991 1019 1311 708 \n",
       "Q 1631 397 2203 397 \n",
       "Q 2534 397 2845 478 \n",
       "Q 3156 559 3463 722 \n",
       "L 3463 178 \n",
       "Q 3153 47 2828 -22 \n",
       "Q 2503 -91 2169 -91 \n",
       "Q 1331 -91 842 396 \n",
       "Q 353 884 353 1716 \n",
       "Q 353 2575 817 3079 \n",
       "Q 1281 3584 2069 3584 \n",
       "Q 2775 3584 3186 3129 \n",
       "Q 3597 2675 3597 1894 \n",
       "z\n",
       "M 3022 2063 \n",
       "Q 3016 2534 2758 2815 \n",
       "Q 2500 3097 2075 3097 \n",
       "Q 1594 3097 1305 2825 \n",
       "Q 1016 2553 972 2059 \n",
       "L 3022 2063 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       <path id=\"DejaVuSans-70\" d=\"M 1159 525 \n",
       "L 1159 -1331 \n",
       "L 581 -1331 \n",
       "L 581 3500 \n",
       "L 1159 3500 \n",
       "L 1159 2969 \n",
       "Q 1341 3281 1617 3432 \n",
       "Q 1894 3584 2278 3584 \n",
       "Q 2916 3584 3314 3078 \n",
       "Q 3713 2572 3713 1747 \n",
       "Q 3713 922 3314 415 \n",
       "Q 2916 -91 2278 -91 \n",
       "Q 1894 -91 1617 61 \n",
       "Q 1341 213 1159 525 \n",
       "z\n",
       "M 3116 1747 \n",
       "Q 3116 2381 2855 2742 \n",
       "Q 2594 3103 2138 3103 \n",
       "Q 1681 3103 1420 2742 \n",
       "Q 1159 2381 1159 1747 \n",
       "Q 1159 1113 1420 752 \n",
       "Q 1681 391 2138 391 \n",
       "Q 2594 391 2855 752 \n",
       "Q 3116 1113 3116 1747 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       <path id=\"DejaVuSans-6f\" d=\"M 1959 3097 \n",
       "Q 1497 3097 1228 2736 \n",
       "Q 959 2375 959 1747 \n",
       "Q 959 1119 1226 758 \n",
       "Q 1494 397 1959 397 \n",
       "Q 2419 397 2687 759 \n",
       "Q 2956 1122 2956 1747 \n",
       "Q 2956 2369 2687 2733 \n",
       "Q 2419 3097 1959 3097 \n",
       "z\n",
       "M 1959 3584 \n",
       "Q 2709 3584 3137 3096 \n",
       "Q 3566 2609 3566 1747 \n",
       "Q 3566 888 3137 398 \n",
       "Q 2709 -91 1959 -91 \n",
       "Q 1206 -91 779 398 \n",
       "Q 353 888 353 1747 \n",
       "Q 353 2609 779 3096 \n",
       "Q 1206 3584 1959 3584 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       <path id=\"DejaVuSans-63\" d=\"M 3122 3366 \n",
       "L 3122 2828 \n",
       "Q 2878 2963 2633 3030 \n",
       "Q 2388 3097 2138 3097 \n",
       "Q 1578 3097 1268 2742 \n",
       "Q 959 2388 959 1747 \n",
       "Q 959 1106 1268 751 \n",
       "Q 1578 397 2138 397 \n",
       "Q 2388 397 2633 464 \n",
       "Q 2878 531 3122 666 \n",
       "L 3122 134 \n",
       "Q 2881 22 2623 -34 \n",
       "Q 2366 -91 2075 -91 \n",
       "Q 1284 -91 818 406 \n",
       "Q 353 903 353 1747 \n",
       "Q 353 2603 823 3093 \n",
       "Q 1294 3584 2113 3584 \n",
       "Q 2378 3584 2631 3529 \n",
       "Q 2884 3475 3122 3366 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       <path id=\"DejaVuSans-68\" d=\"M 3513 2113 \n",
       "L 3513 0 \n",
       "L 2938 0 \n",
       "L 2938 2094 \n",
       "Q 2938 2591 2744 2837 \n",
       "Q 2550 3084 2163 3084 \n",
       "Q 1697 3084 1428 2787 \n",
       "Q 1159 2491 1159 1978 \n",
       "L 1159 0 \n",
       "L 581 0 \n",
       "L 581 4863 \n",
       "L 1159 4863 \n",
       "L 1159 2956 \n",
       "Q 1366 3272 1645 3428 \n",
       "Q 1925 3584 2291 3584 \n",
       "Q 2894 3584 3203 3211 \n",
       "Q 3513 2838 3513 2113 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "      </defs>\n",
       "      <use xlink:href=\"#DejaVuSans-65\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-70\" x=\"61.523438\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-6f\" x=\"125\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-63\" x=\"186.181641\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-68\" x=\"241.162109\"/>\n",
       "     </g>\n",
       "    </g>\n",
       "   </g>\n",
       "   <g id=\"matplotlib.axis_2\">\n",
       "    <g id=\"ytick_1\">\n",
       "     <g id=\"line2d_9\">\n",
       "      <path d=\"M 30.103125 118.308666 \n",
       "L 225.403125 118.308666 \n",
       "\" clip-path=\"url(#pb5435cb16d)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
       "     </g>\n",
       "     <g id=\"line2d_10\">\n",
       "      <defs>\n",
       "       <path id=\"m8b88989f37\" d=\"M 0 0 \n",
       "L -3.5 0 \n",
       "\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </defs>\n",
       "      <g>\n",
       "       <use xlink:href=\"#m8b88989f37\" x=\"30.103125\" y=\"118.308666\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_6\">\n",
       "      <!-- 0.5 -->\n",
       "      <g transform=\"translate(7.2 122.107885) scale(0.1 -0.1)\">\n",
       "       <defs>\n",
       "        <path id=\"DejaVuSans-2e\" d=\"M 684 794 \n",
       "L 1344 794 \n",
       "L 1344 0 \n",
       "L 684 0 \n",
       "L 684 794 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       </defs>\n",
       "       <use xlink:href=\"#DejaVuSans-30\"/>\n",
       "       <use xlink:href=\"#DejaVuSans-2e\" x=\"63.623047\"/>\n",
       "       <use xlink:href=\"#DejaVuSans-35\" x=\"95.410156\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"ytick_2\">\n",
       "     <g id=\"line2d_11\">\n",
       "      <path d=\"M 30.103125 90.053555 \n",
       "L 225.403125 90.053555 \n",
       "\" clip-path=\"url(#pb5435cb16d)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
       "     </g>\n",
       "     <g id=\"line2d_12\">\n",
       "      <g>\n",
       "       <use xlink:href=\"#m8b88989f37\" x=\"30.103125\" y=\"90.053555\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_7\">\n",
       "      <!-- 1.0 -->\n",
       "      <g transform=\"translate(7.2 93.852774) scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-31\"/>\n",
       "       <use xlink:href=\"#DejaVuSans-2e\" x=\"63.623047\"/>\n",
       "       <use xlink:href=\"#DejaVuSans-30\" x=\"95.410156\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"ytick_3\">\n",
       "     <g id=\"line2d_13\">\n",
       "      <path d=\"M 30.103125 61.798444 \n",
       "L 225.403125 61.798444 \n",
       "\" clip-path=\"url(#pb5435cb16d)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
       "     </g>\n",
       "     <g id=\"line2d_14\">\n",
       "      <g>\n",
       "       <use xlink:href=\"#m8b88989f37\" x=\"30.103125\" y=\"61.798444\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_8\">\n",
       "      <!-- 1.5 -->\n",
       "      <g transform=\"translate(7.2 65.597662) scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-31\"/>\n",
       "       <use xlink:href=\"#DejaVuSans-2e\" x=\"63.623047\"/>\n",
       "       <use xlink:href=\"#DejaVuSans-35\" x=\"95.410156\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"ytick_4\">\n",
       "     <g id=\"line2d_15\">\n",
       "      <path d=\"M 30.103125 33.543332 \n",
       "L 225.403125 33.543332 \n",
       "\" clip-path=\"url(#pb5435cb16d)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
       "     </g>\n",
       "     <g id=\"line2d_16\">\n",
       "      <g>\n",
       "       <use xlink:href=\"#m8b88989f37\" x=\"30.103125\" y=\"33.543332\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_9\">\n",
       "      <!-- 2.0 -->\n",
       "      <g transform=\"translate(7.2 37.342551) scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-32\"/>\n",
       "       <use xlink:href=\"#DejaVuSans-2e\" x=\"63.623047\"/>\n",
       "       <use xlink:href=\"#DejaVuSans-30\" x=\"95.410156\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "   </g>\n",
       "   <g id=\"line2d_17\">\n",
       "    <path d=\"M 21.659704 13.5 \n",
       "L 23.49523 14.35727 \n",
       "L 25.330757 15.146281 \n",
       "L 27.166283 17.161905 \n",
       "L 29.001809 15.723152 \n",
       "L 30.103125 17.248488 \n",
       "L 31.938651 28.860693 \n",
       "L 33.774178 26.018259 \n",
       "L 35.609704 29.609338 \n",
       "L 37.44523 29.831117 \n",
       "L 39.280757 31.866883 \n",
       "L 40.382072 32.119794 \n",
       "L 42.217599 33.790844 \n",
       "L 44.053125 35.40172 \n",
       "L 45.888651 40.566627 \n",
       "L 47.724178 40.623021 \n",
       "L 49.559704 42.315364 \n",
       "L 50.66102 41.584603 \n",
       "L 52.496546 54.327164 \n",
       "L 54.332072 51.086817 \n",
       "L 56.167599 50.951139 \n",
       "L 58.003125 49.439654 \n",
       "L 59.838651 48.368145 \n",
       "L 60.939967 47.410899 \n",
       "L 62.775493 55.498875 \n",
       "L 64.61102 53.676244 \n",
       "L 66.446546 53.580745 \n",
       "L 68.282072 54.981826 \n",
       "L 70.117599 54.007938 \n",
       "L 71.218914 53.469045 \n",
       "L 73.054441 68.462203 \n",
       "L 74.889967 66.592277 \n",
       "L 76.725493 62.101569 \n",
       "L 78.56102 62.2664 \n",
       "L 80.396546 60.610738 \n",
       "L 81.497862 59.52485 \n",
       "L 83.333388 62.256785 \n",
       "L 85.168914 64.878892 \n",
       "L 87.004441 64.294448 \n",
       "L 88.839967 64.179114 \n",
       "L 90.675493 64.813435 \n",
       "L 91.776809 63.580285 \n",
       "L 93.612336 65.364682 \n",
       "L 95.447862 63.571038 \n",
       "L 97.283388 64.33483 \n",
       "L 99.118914 64.485692 \n",
       "L 100.954441 65.291561 \n",
       "L 102.055757 64.943957 \n",
       "L 103.891283 75.86086 \n",
       "L 105.726809 72.80647 \n",
       "L 107.562336 71.185795 \n",
       "L 109.397862 71.680823 \n",
       "L 111.233388 71.699587 \n",
       "L 112.334704 71.399096 \n",
       "L 114.17023 80.170834 \n",
       "L 116.005757 75.981402 \n",
       "L 117.841283 76.712821 \n",
       "L 119.676809 77.010238 \n",
       "L 121.512336 76.024449 \n",
       "L 122.613651 75.123448 \n",
       "L 124.449178 84.011425 \n",
       "L 126.284704 82.805743 \n",
       "L 128.12023 81.14003 \n",
       "L 129.955757 80.756018 \n",
       "L 131.791283 78.708344 \n",
       "L 132.892599 76.901838 \n",
       "L 134.728125 86.780082 \n",
       "L 136.563651 86.360812 \n",
       "L 138.399178 85.919905 \n",
       "L 140.234704 85.775946 \n",
       "L 142.07023 84.643972 \n",
       "L 143.171546 84.022922 \n",
       "L 145.007072 88.507369 \n",
       "L 146.842599 89.286506 \n",
       "L 148.678125 90.609337 \n",
       "L 150.513651 88.388744 \n",
       "L 152.349178 86.216322 \n",
       "L 153.450493 85.807542 \n",
       "L 155.28602 89.895891 \n",
       "L 157.121546 91.828211 \n",
       "L 158.957072 92.930295 \n",
       "L 160.792599 91.801156 \n",
       "L 162.628125 91.353267 \n",
       "L 163.729441 90.299155 \n",
       "L 165.564967 101.013725 \n",
       "L 167.400493 97.999588 \n",
       "L 169.23602 93.823449 \n",
       "L 171.071546 93.045968 \n",
       "L 172.907072 92.81248 \n",
       "L 174.008388 92.542337 \n",
       "L 175.843914 101.612634 \n",
       "L 177.679441 102.191293 \n",
       "L 179.514967 101.16972 \n",
       "L 181.350493 98.782112 \n",
       "L 183.18602 98.273023 \n",
       "L 184.287336 98.194657 \n",
       "L 186.122862 109.242837 \n",
       "L 187.958388 108.668082 \n",
       "L 189.793914 107.725687 \n",
       "L 191.629441 105.623783 \n",
       "L 193.464967 104.144019 \n",
       "L 194.566283 103.849949 \n",
       "L 196.401809 105.930056 \n",
       "L 198.237336 105.118384 \n",
       "L 200.072862 103.730427 \n",
       "L 201.908388 102.732506 \n",
       "L 203.743914 101.334568 \n",
       "L 204.84523 100.527439 \n",
       "L 206.680757 105.743826 \n",
       "L 208.516283 105.471757 \n",
       "L 210.351809 106.359033 \n",
       "L 212.187336 106.26261 \n",
       "L 214.022862 106.412409 \n",
       "L 215.124178 105.967192 \n",
       "L 216.959704 102.295335 \n",
       "L 218.79523 108.468188 \n",
       "L 220.630757 108.556439 \n",
       "L 222.466283 110.461015 \n",
       "L 224.301809 109.545645 \n",
       "L 225.403125 109.598715 \n",
       "\" clip-path=\"url(#pb5435cb16d)\" style=\"fill: none; stroke: #1f77b4; stroke-width: 1.5; stroke-linecap: square\"/>\n",
       "   </g>\n",
       "   <g id=\"line2d_18\">\n",
       "    <path d=\"M 21.659704 139.5 \n",
       "L 23.49523 138.617028 \n",
       "L 25.330757 137.734056 \n",
       "L 27.166283 136.939381 \n",
       "L 29.001809 137.239591 \n",
       "L 30.103125 136.788014 \n",
       "L 31.938651 133.1426 \n",
       "L 33.774178 133.495789 \n",
       "L 35.609704 132.200763 \n",
       "L 37.44523 132.259628 \n",
       "L 39.280757 131.23538 \n",
       "L 40.382072 130.922555 \n",
       "L 42.217599 131.376655 \n",
       "L 44.053125 129.787305 \n",
       "L 45.888651 127.727037 \n",
       "L 47.724178 128.021361 \n",
       "L 49.559704 127.491578 \n",
       "L 50.66102 127.706014 \n",
       "L 52.496546 123.253311 \n",
       "L 54.332072 123.6065 \n",
       "L 56.167599 124.783796 \n",
       "L 58.003125 124.930958 \n",
       "L 59.838651 125.443082 \n",
       "L 60.939967 126.003139 \n",
       "L 62.775493 121.487366 \n",
       "L 64.61102 123.429905 \n",
       "L 66.446546 122.900122 \n",
       "L 68.282072 122.546933 \n",
       "L 70.117599 122.758846 \n",
       "L 71.218914 123.165014 \n",
       "L 73.054441 118.661855 \n",
       "L 74.889967 119.544828 \n",
       "L 76.725493 121.251907 \n",
       "L 78.56102 121.663961 \n",
       "L 80.396546 122.052469 \n",
       "L 81.497862 122.282041 \n",
       "L 83.333388 122.193744 \n",
       "L 85.168914 119.898016 \n",
       "L 87.004441 120.31007 \n",
       "L 88.839967 119.721422 \n",
       "L 90.675493 120.286524 \n",
       "L 91.776809 120.768375 \n",
       "L 93.612336 118.661855 \n",
       "L 95.447862 119.368233 \n",
       "L 97.283388 119.250503 \n",
       "L 99.118914 119.544828 \n",
       "L 100.954441 119.509509 \n",
       "L 102.055757 119.317778 \n",
       "L 103.891283 116.542722 \n",
       "L 105.726809 118.308666 \n",
       "L 107.562336 117.837748 \n",
       "L 109.397862 117.160803 \n",
       "L 111.233388 117.2491 \n",
       "L 112.334704 117.425694 \n",
       "L 114.17023 116.189533 \n",
       "L 116.005757 117.778883 \n",
       "L 117.841283 116.424992 \n",
       "L 119.676809 115.924641 \n",
       "L 121.512336 116.330809 \n",
       "L 122.613651 116.416583 \n",
       "L 124.449178 114.423589 \n",
       "L 126.284704 114.776778 \n",
       "L 128.12023 114.541318 \n",
       "L 129.955757 114.511886 \n",
       "L 131.791283 114.70614 \n",
       "L 132.892599 115.092125 \n",
       "L 134.728125 110.8917 \n",
       "L 136.563651 111.068294 \n",
       "L 138.399178 111.480348 \n",
       "L 140.234704 111.244889 \n",
       "L 142.07023 112.16318 \n",
       "L 143.171546 112.506277 \n",
       "L 145.007072 110.8917 \n",
       "L 146.842599 109.832133 \n",
       "L 148.678125 109.714403 \n",
       "L 150.513651 110.361916 \n",
       "L 152.349178 111.52744 \n",
       "L 153.450493 111.560236 \n",
       "L 155.28602 107.006622 \n",
       "L 157.121546 107.536405 \n",
       "L 158.957072 108.066189 \n",
       "L 160.792599 108.949161 \n",
       "L 162.628125 109.196393 \n",
       "L 163.729441 109.731222 \n",
       "L 165.564967 106.300244 \n",
       "L 167.400493 107.359811 \n",
       "L 169.23602 108.772566 \n",
       "L 171.071546 109.655539 \n",
       "L 172.907072 109.267031 \n",
       "L 174.008388 109.352805 \n",
       "L 175.843914 104.5343 \n",
       "L 177.679441 105.240677 \n",
       "L 179.514967 106.182514 \n",
       "L 181.350493 107.094919 \n",
       "L 183.18602 107.430449 \n",
       "L 184.287336 107.64993 \n",
       "L 186.122862 101.3556 \n",
       "L 187.958388 102.591761 \n",
       "L 189.793914 103.474733 \n",
       "L 191.629441 104.269408 \n",
       "L 193.464967 104.816851 \n",
       "L 194.566283 105.001014 \n",
       "L 196.401809 103.827922 \n",
       "L 198.237336 104.887489 \n",
       "L 200.072862 105.829326 \n",
       "L 201.908388 106.476839 \n",
       "L 203.743914 106.653433 \n",
       "L 204.84523 106.956166 \n",
       "L 206.680757 104.5343 \n",
       "L 208.516283 104.887489 \n",
       "L 210.351809 104.063381 \n",
       "L 212.187336 104.622597 \n",
       "L 214.022862 104.958126 \n",
       "L 215.124178 105.190222 \n",
       "L 216.959704 104.5343 \n",
       "L 218.79523 102.415166 \n",
       "L 220.630757 102.061977 \n",
       "L 222.466283 101.443897 \n",
       "L 224.301809 101.850064 \n",
       "L 225.403125 101.97368 \n",
       "\" clip-path=\"url(#pb5435cb16d)\" style=\"fill: none; stroke-dasharray: 5.55,2.4; stroke-dashoffset: 0; stroke: #bf00bf; stroke-width: 1.5\"/>\n",
       "   </g>\n",
       "   <g id=\"line2d_19\">\n",
       "    <path d=\"M 30.103125 130.670278 \n",
       "L 40.382072 122.723528 \n",
       "L 50.66102 129.787305 \n",
       "L 60.939967 125.372444 \n",
       "L 71.218914 118.308666 \n",
       "L 81.497862 127.138389 \n",
       "L 91.776809 123.6065 \n",
       "L 102.055757 124.489472 \n",
       "L 112.334704 128.904333 \n",
       "L 122.613651 124.489472 \n",
       "L 132.892599 117.425694 \n",
       "L 143.171546 121.840555 \n",
       "L 153.450493 120.074611 \n",
       "L 163.729441 127.138389 \n",
       "L 174.008388 122.723528 \n",
       "L 184.287336 122.723528 \n",
       "L 194.566283 122.723528 \n",
       "L 204.84523 122.723528 \n",
       "L 215.124178 124.489472 \n",
       "L 225.403125 121.840555 \n",
       "\" clip-path=\"url(#pb5435cb16d)\" style=\"fill: none; stroke-dasharray: 9.6,2.4,1.5,2.4; stroke-dashoffset: 0; stroke: #008000; stroke-width: 1.5\"/>\n",
       "   </g>\n",
       "   <g id=\"patch_3\">\n",
       "    <path d=\"M 30.103125 145.8 \n",
       "L 30.103125 7.2 \n",
       "\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
       "   </g>\n",
       "   <g id=\"patch_4\">\n",
       "    <path d=\"M 225.403125 145.8 \n",
       "L 225.403125 7.2 \n",
       "\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
       "   </g>\n",
       "   <g id=\"patch_5\">\n",
       "    <path d=\"M 30.103125 145.8 \n",
       "L 225.403125 145.8 \n",
       "\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
       "   </g>\n",
       "   <g id=\"patch_6\">\n",
       "    <path d=\"M 30.103125 7.2 \n",
       "L 225.403125 7.2 \n",
       "\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
       "   </g>\n",
       "   <g id=\"legend_1\">\n",
       "    <g id=\"patch_7\">\n",
       "     <path d=\"M 140.634375 59.234375 \n",
       "L 218.403125 59.234375 \n",
       "Q 220.403125 59.234375 220.403125 57.234375 \n",
       "L 220.403125 14.2 \n",
       "Q 220.403125 12.2 218.403125 12.2 \n",
       "L 140.634375 12.2 \n",
       "Q 138.634375 12.2 138.634375 14.2 \n",
       "L 138.634375 57.234375 \n",
       "Q 138.634375 59.234375 140.634375 59.234375 \n",
       "z\n",
       "\" style=\"fill: #ffffff; opacity: 0.8; stroke: #cccccc; stroke-linejoin: miter\"/>\n",
       "    </g>\n",
       "    <g id=\"line2d_20\">\n",
       "     <path d=\"M 142.634375 20.298438 \n",
       "L 152.634375 20.298438 \n",
       "L 162.634375 20.298438 \n",
       "\" style=\"fill: none; stroke: #1f77b4; stroke-width: 1.5; stroke-linecap: square\"/>\n",
       "    </g>\n",
       "    <g id=\"text_10\">\n",
       "     <!-- train loss -->\n",
       "     <g transform=\"translate(170.634375 23.798438) scale(0.1 -0.1)\">\n",
       "      <defs>\n",
       "       <path id=\"DejaVuSans-74\" d=\"M 1172 4494 \n",
       "L 1172 3500 \n",
       "L 2356 3500 \n",
       "L 2356 3053 \n",
       "L 1172 3053 \n",
       "L 1172 1153 \n",
       "Q 1172 725 1289 603 \n",
       "Q 1406 481 1766 481 \n",
       "L 2356 481 \n",
       "L 2356 0 \n",
       "L 1766 0 \n",
       "Q 1100 0 847 248 \n",
       "Q 594 497 594 1153 \n",
       "L 594 3053 \n",
       "L 172 3053 \n",
       "L 172 3500 \n",
       "L 594 3500 \n",
       "L 594 4494 \n",
       "L 1172 4494 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       <path id=\"DejaVuSans-72\" d=\"M 2631 2963 \n",
       "Q 2534 3019 2420 3045 \n",
       "Q 2306 3072 2169 3072 \n",
       "Q 1681 3072 1420 2755 \n",
       "Q 1159 2438 1159 1844 \n",
       "L 1159 0 \n",
       "L 581 0 \n",
       "L 581 3500 \n",
       "L 1159 3500 \n",
       "L 1159 2956 \n",
       "Q 1341 3275 1631 3429 \n",
       "Q 1922 3584 2338 3584 \n",
       "Q 2397 3584 2469 3576 \n",
       "Q 2541 3569 2628 3553 \n",
       "L 2631 2963 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       <path id=\"DejaVuSans-61\" d=\"M 2194 1759 \n",
       "Q 1497 1759 1228 1600 \n",
       "Q 959 1441 959 1056 \n",
       "Q 959 750 1161 570 \n",
       "Q 1363 391 1709 391 \n",
       "Q 2188 391 2477 730 \n",
       "Q 2766 1069 2766 1631 \n",
       "L 2766 1759 \n",
       "L 2194 1759 \n",
       "z\n",
       "M 3341 1997 \n",
       "L 3341 0 \n",
       "L 2766 0 \n",
       "L 2766 531 \n",
       "Q 2569 213 2275 61 \n",
       "Q 1981 -91 1556 -91 \n",
       "Q 1019 -91 701 211 \n",
       "Q 384 513 384 1019 \n",
       "Q 384 1609 779 1909 \n",
       "Q 1175 2209 1959 2209 \n",
       "L 2766 2209 \n",
       "L 2766 2266 \n",
       "Q 2766 2663 2505 2880 \n",
       "Q 2244 3097 1772 3097 \n",
       "Q 1472 3097 1187 3025 \n",
       "Q 903 2953 641 2809 \n",
       "L 641 3341 \n",
       "Q 956 3463 1253 3523 \n",
       "Q 1550 3584 1831 3584 \n",
       "Q 2591 3584 2966 3190 \n",
       "Q 3341 2797 3341 1997 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       <path id=\"DejaVuSans-69\" d=\"M 603 3500 \n",
       "L 1178 3500 \n",
       "L 1178 0 \n",
       "L 603 0 \n",
       "L 603 3500 \n",
       "z\n",
       "M 603 4863 \n",
       "L 1178 4863 \n",
       "L 1178 4134 \n",
       "L 603 4134 \n",
       "L 603 4863 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       <path id=\"DejaVuSans-6e\" d=\"M 3513 2113 \n",
       "L 3513 0 \n",
       "L 2938 0 \n",
       "L 2938 2094 \n",
       "Q 2938 2591 2744 2837 \n",
       "Q 2550 3084 2163 3084 \n",
       "Q 1697 3084 1428 2787 \n",
       "Q 1159 2491 1159 1978 \n",
       "L 1159 0 \n",
       "L 581 0 \n",
       "L 581 3500 \n",
       "L 1159 3500 \n",
       "L 1159 2956 \n",
       "Q 1366 3272 1645 3428 \n",
       "Q 1925 3584 2291 3584 \n",
       "Q 2894 3584 3203 3211 \n",
       "Q 3513 2838 3513 2113 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       <path id=\"DejaVuSans-20\" transform=\"scale(0.015625)\"/>\n",
       "       <path id=\"DejaVuSans-6c\" d=\"M 603 4863 \n",
       "L 1178 4863 \n",
       "L 1178 0 \n",
       "L 603 0 \n",
       "L 603 4863 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       <path id=\"DejaVuSans-73\" d=\"M 2834 3397 \n",
       "L 2834 2853 \n",
       "Q 2591 2978 2328 3040 \n",
       "Q 2066 3103 1784 3103 \n",
       "Q 1356 3103 1142 2972 \n",
       "Q 928 2841 928 2578 \n",
       "Q 928 2378 1081 2264 \n",
       "Q 1234 2150 1697 2047 \n",
       "L 1894 2003 \n",
       "Q 2506 1872 2764 1633 \n",
       "Q 3022 1394 3022 966 \n",
       "Q 3022 478 2636 193 \n",
       "Q 2250 -91 1575 -91 \n",
       "Q 1294 -91 989 -36 \n",
       "Q 684 19 347 128 \n",
       "L 347 722 \n",
       "Q 666 556 975 473 \n",
       "Q 1284 391 1588 391 \n",
       "Q 1994 391 2212 530 \n",
       "Q 2431 669 2431 922 \n",
       "Q 2431 1156 2273 1281 \n",
       "Q 2116 1406 1581 1522 \n",
       "L 1381 1569 \n",
       "Q 847 1681 609 1914 \n",
       "Q 372 2147 372 2553 \n",
       "Q 372 3047 722 3315 \n",
       "Q 1072 3584 1716 3584 \n",
       "Q 2034 3584 2315 3537 \n",
       "Q 2597 3491 2834 3397 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "      </defs>\n",
       "      <use xlink:href=\"#DejaVuSans-74\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-72\" x=\"39.208984\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-61\" x=\"80.322266\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-69\" x=\"141.601562\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-6e\" x=\"169.384766\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-20\" x=\"232.763672\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-6c\" x=\"264.550781\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-6f\" x=\"292.333984\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-73\" x=\"353.515625\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-73\" x=\"405.615234\"/>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"line2d_21\">\n",
       "     <path d=\"M 142.634375 34.976563 \n",
       "L 152.634375 34.976563 \n",
       "L 162.634375 34.976563 \n",
       "\" style=\"fill: none; stroke-dasharray: 5.55,2.4; stroke-dashoffset: 0; stroke: #bf00bf; stroke-width: 1.5\"/>\n",
       "    </g>\n",
       "    <g id=\"text_11\">\n",
       "     <!-- train acc -->\n",
       "     <g transform=\"translate(170.634375 38.476563) scale(0.1 -0.1)\">\n",
       "      <use xlink:href=\"#DejaVuSans-74\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-72\" x=\"39.208984\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-61\" x=\"80.322266\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-69\" x=\"141.601562\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-6e\" x=\"169.384766\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-20\" x=\"232.763672\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-61\" x=\"264.550781\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-63\" x=\"325.830078\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-63\" x=\"380.810547\"/>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"line2d_22\">\n",
       "     <path d=\"M 142.634375 49.654688 \n",
       "L 152.634375 49.654688 \n",
       "L 162.634375 49.654688 \n",
       "\" style=\"fill: none; stroke-dasharray: 9.6,2.4,1.5,2.4; stroke-dashoffset: 0; stroke: #008000; stroke-width: 1.5\"/>\n",
       "    </g>\n",
       "    <g id=\"text_12\">\n",
       "     <!-- valid acc -->\n",
       "     <g transform=\"translate(170.634375 53.154688) scale(0.1 -0.1)\">\n",
       "      <defs>\n",
       "       <path id=\"DejaVuSans-76\" d=\"M 191 3500 \n",
       "L 800 3500 \n",
       "L 1894 563 \n",
       "L 2988 3500 \n",
       "L 3597 3500 \n",
       "L 2284 0 \n",
       "L 1503 0 \n",
       "L 191 3500 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       <path id=\"DejaVuSans-64\" d=\"M 2906 2969 \n",
       "L 2906 4863 \n",
       "L 3481 4863 \n",
       "L 3481 0 \n",
       "L 2906 0 \n",
       "L 2906 525 \n",
       "Q 2725 213 2448 61 \n",
       "Q 2172 -91 1784 -91 \n",
       "Q 1150 -91 751 415 \n",
       "Q 353 922 353 1747 \n",
       "Q 353 2572 751 3078 \n",
       "Q 1150 3584 1784 3584 \n",
       "Q 2172 3584 2448 3432 \n",
       "Q 2725 3281 2906 2969 \n",
       "z\n",
       "M 947 1747 \n",
       "Q 947 1113 1208 752 \n",
       "Q 1469 391 1925 391 \n",
       "Q 2381 391 2643 752 \n",
       "Q 2906 1113 2906 1747 \n",
       "Q 2906 2381 2643 2742 \n",
       "Q 2381 3103 1925 3103 \n",
       "Q 1469 3103 1208 2742 \n",
       "Q 947 2381 947 1747 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "      </defs>\n",
       "      <use xlink:href=\"#DejaVuSans-76\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-61\" x=\"59.179688\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-6c\" x=\"120.458984\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-69\" x=\"148.242188\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-64\" x=\"176.025391\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-20\" x=\"239.501953\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-61\" x=\"271.289062\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-63\" x=\"332.568359\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-63\" x=\"387.548828\"/>\n",
       "     </g>\n",
       "    </g>\n",
       "   </g>\n",
       "  </g>\n",
       " </g>\n",
       " <defs>\n",
       "  <clipPath id=\"pb5435cb16d\">\n",
       "   <rect x=\"30.103125\" y=\"7.2\" width=\"195.3\" height=\"138.6\"/>\n",
       "  </clipPath>\n",
       " </defs>\n",
       "</svg>\n"
      ],
      "text/plain": [
       "<Figure size 350x250 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
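   ],
   "source": [
    "# Train for 20 epochs with learning rate 2e-4 and weight decay 5e-4\n",
    "devices, num_epochs, lr, wd = d2l.try_all_gpus(), 20, 2e-4, 5e-4\n",
    "# Multiply the learning rate by lr_decay every lr_period epochs\n",
    "lr_period, lr_decay, net = 4, 0.9, get_net()\n",
    "# Forward one batch through the network before training; this also\n",
    "# initializes any lazily defined layers\n",
    "net(next(iter(train_iter))[0])\n",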
    "train(net, train_iter, valid_iter, num_epochs, lr, wd, devices, lr_period,\n",
    "      lr_decay)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bd51eac4",
   "metadata": {
    "origin_pos": 42
   },
   "source": [
    "## [**Classifying the Testing Set**] and Submitting Results on Kaggle\n",
    "\n",
    "After finding a model and hyperparameters that look promising,\n",
    "we retrain the model on all the labeled data (including the validation set) and then classify the testing set.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "a66ef205",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:30:37.501313Z",
     "iopub.status.busy": "2023-08-18T19:30:37.500748Z",
     "iopub.status.idle": "2023-08-18T19:31:40.934103Z",
     "shell.execute_reply": "2023-08-18T19:31:40.932837Z"
    },
    "origin_pos": 44,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss 0.608, train acc 0.786\n",
      "1040.8 examples/sec on [device(type='cuda', index=0), device(type='cuda', index=1)]\n"
     ]
    },
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
       "  \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
       "<svg xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"238.965625pt\" height=\"183.35625pt\" viewBox=\"0 0 238.965625 183.35625\" xmlns=\"http://www.w3.org/2000/svg\" version=\"1.1\">\n",
       " <metadata>\n",
       "  <rdf:RDF xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\n",
       "   <cc:Work>\n",
       "    <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\n",
       "    <dc:date>2023-08-18T19:31:40.877905</dc:date>\n",
       "    <dc:format>image/svg+xml</dc:format>\n",
       "    <dc:creator>\n",
       "     <cc:Agent>\n",
       "      <dc:title>Matplotlib v3.7.2, https://matplotlib.org/</dc:title>\n",
       "     </cc:Agent>\n",
       "    </dc:creator>\n",
       "   </cc:Work>\n",
       "  </rdf:RDF>\n",
       " </metadata>\n",
       " <defs>\n",
       "  <style type=\"text/css\">*{stroke-linejoin: round; stroke-linecap: butt}</style>\n",
       " </defs>\n",
       " <g id=\"figure_1\">\n",
       "  <g id=\"patch_1\">\n",
       "   <path d=\"M 0 183.35625 \n",
       "L 238.965625 183.35625 \n",
       "L 238.965625 0 \n",
       "L 0 0 \n",
       "z\n",
       "\" style=\"fill: #ffffff\"/>\n",
       "  </g>\n",
       "  <g id=\"axes_1\">\n",
       "   <g id=\"patch_2\">\n",
       "    <path d=\"M 30.103125 145.8 \n",
       "L 225.403125 145.8 \n",
       "L 225.403125 7.2 \n",
       "L 30.103125 7.2 \n",
       "z\n",
       "\" style=\"fill: #ffffff\"/>\n",
       "   </g>\n",
       "   <g id=\"matplotlib.axis_1\">\n",
       "    <g id=\"xtick_1\">\n",
       "     <g id=\"line2d_1\">\n",
       "      <path d=\"M 71.218914 145.8 \n",
       "L 71.218914 7.2 \n",
       "\" clip-path=\"url(#pd4b015918e)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
       "     </g>\n",
       "     <g id=\"line2d_2\">\n",
       "      <defs>\n",
       "       <path id=\"m9036d8c040\" d=\"M 0 0 \n",
       "L 0 3.5 \n",
       "\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </defs>\n",
       "      <g>\n",
       "       <use xlink:href=\"#m9036d8c040\" x=\"71.218914\" y=\"145.8\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_1\">\n",
       "      <!-- 5 -->\n",
       "      <g transform=\"translate(68.037664 160.398438) scale(0.1 -0.1)\">\n",
       "       <defs>\n",
       "        <path id=\"DejaVuSans-35\" d=\"M 691 4666 \n",
       "L 3169 4666 \n",
       "L 3169 4134 \n",
       "L 1269 4134 \n",
       "L 1269 2991 \n",
       "Q 1406 3038 1543 3061 \n",
       "Q 1681 3084 1819 3084 \n",
       "Q 2600 3084 3056 2656 \n",
       "Q 3513 2228 3513 1497 \n",
       "Q 3513 744 3044 326 \n",
       "Q 2575 -91 1722 -91 \n",
       "Q 1428 -91 1123 -41 \n",
       "Q 819 9 494 109 \n",
       "L 494 744 \n",
       "Q 775 591 1075 516 \n",
       "Q 1375 441 1709 441 \n",
       "Q 2250 441 2565 725 \n",
       "Q 2881 1009 2881 1497 \n",
       "Q 2881 1984 2565 2268 \n",
       "Q 2250 2553 1709 2553 \n",
       "Q 1456 2553 1204 2497 \n",
       "Q 953 2441 691 2322 \n",
       "L 691 4666 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       </defs>\n",
       "       <use xlink:href=\"#DejaVuSans-35\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"xtick_2\">\n",
       "     <g id=\"line2d_3\">\n",
       "      <path d=\"M 122.613651 145.8 \n",
       "L 122.613651 7.2 \n",
       "\" clip-path=\"url(#pd4b015918e)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
       "     </g>\n",
       "     <g id=\"line2d_4\">\n",
       "      <g>\n",
       "       <use xlink:href=\"#m9036d8c040\" x=\"122.613651\" y=\"145.8\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_2\">\n",
       "      <!-- 10 -->\n",
       "      <g transform=\"translate(116.251151 160.398438) scale(0.1 -0.1)\">\n",
       "       <defs>\n",
       "        <path id=\"DejaVuSans-31\" d=\"M 794 531 \n",
       "L 1825 531 \n",
       "L 1825 4091 \n",
       "L 703 3866 \n",
       "L 703 4441 \n",
       "L 1819 4666 \n",
       "L 2450 4666 \n",
       "L 2450 531 \n",
       "L 3481 531 \n",
       "L 3481 0 \n",
       "L 794 0 \n",
       "L 794 531 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "        <path id=\"DejaVuSans-30\" d=\"M 2034 4250 \n",
       "Q 1547 4250 1301 3770 \n",
       "Q 1056 3291 1056 2328 \n",
       "Q 1056 1369 1301 889 \n",
       "Q 1547 409 2034 409 \n",
       "Q 2525 409 2770 889 \n",
       "Q 3016 1369 3016 2328 \n",
       "Q 3016 3291 2770 3770 \n",
       "Q 2525 4250 2034 4250 \n",
       "z\n",
       "M 2034 4750 \n",
       "Q 2819 4750 3233 4129 \n",
       "Q 3647 3509 3647 2328 \n",
       "Q 3647 1150 3233 529 \n",
       "Q 2819 -91 2034 -91 \n",
       "Q 1250 -91 836 529 \n",
       "Q 422 1150 422 2328 \n",
       "Q 422 3509 836 4129 \n",
       "Q 1250 4750 2034 4750 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       </defs>\n",
       "       <use xlink:href=\"#DejaVuSans-31\"/>\n",
       "       <use xlink:href=\"#DejaVuSans-30\" x=\"63.623047\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"xtick_3\">\n",
       "     <g id=\"line2d_5\">\n",
       "      <path d=\"M 174.008388 145.8 \n",
       "L 174.008388 7.2 \n",
       "\" clip-path=\"url(#pd4b015918e)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
       "     </g>\n",
       "     <g id=\"line2d_6\">\n",
       "      <g>\n",
       "       <use xlink:href=\"#m9036d8c040\" x=\"174.008388\" y=\"145.8\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_3\">\n",
       "      <!-- 15 -->\n",
       "      <g transform=\"translate(167.645888 160.398438) scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-31\"/>\n",
       "       <use xlink:href=\"#DejaVuSans-35\" x=\"63.623047\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"xtick_4\">\n",
       "     <g id=\"line2d_7\">\n",
       "      <path d=\"M 225.403125 145.8 \n",
       "L 225.403125 7.2 \n",
       "\" clip-path=\"url(#pd4b015918e)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
       "     </g>\n",
       "     <g id=\"line2d_8\">\n",
       "      <g>\n",
       "       <use xlink:href=\"#m9036d8c040\" x=\"225.403125\" y=\"145.8\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_4\">\n",
       "      <!-- 20 -->\n",
       "      <g transform=\"translate(219.040625 160.398438) scale(0.1 -0.1)\">\n",
       "       <defs>\n",
       "        <path id=\"DejaVuSans-32\" d=\"M 1228 531 \n",
       "L 3431 531 \n",
       "L 3431 0 \n",
       "L 469 0 \n",
       "L 469 531 \n",
       "Q 828 903 1448 1529 \n",
       "Q 2069 2156 2228 2338 \n",
       "Q 2531 2678 2651 2914 \n",
       "Q 2772 3150 2772 3378 \n",
       "Q 2772 3750 2511 3984 \n",
       "Q 2250 4219 1831 4219 \n",
       "Q 1534 4219 1204 4116 \n",
       "Q 875 4013 500 3803 \n",
       "L 500 4441 \n",
       "Q 881 4594 1212 4672 \n",
       "Q 1544 4750 1819 4750 \n",
       "Q 2544 4750 2975 4387 \n",
       "Q 3406 4025 3406 3419 \n",
       "Q 3406 3131 3298 2873 \n",
       "Q 3191 2616 2906 2266 \n",
       "Q 2828 2175 2409 1742 \n",
       "Q 1991 1309 1228 531 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       </defs>\n",
       "       <use xlink:href=\"#DejaVuSans-32\"/>\n",
       "       <use xlink:href=\"#DejaVuSans-30\" x=\"63.623047\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"text_5\">\n",
       "     <!-- epoch -->\n",
       "     <g transform=\"translate(112.525 174.076563) scale(0.1 -0.1)\">\n",
       "      <defs>\n",
       "       <path id=\"DejaVuSans-65\" d=\"M 3597 1894 \n",
       "L 3597 1613 \n",
       "L 953 1613 \n",
       "Q 991 1019 1311 708 \n",
       "Q 1631 397 2203 397 \n",
       "Q 2534 397 2845 478 \n",
       "Q 3156 559 3463 722 \n",
       "L 3463 178 \n",
       "Q 3153 47 2828 -22 \n",
       "Q 2503 -91 2169 -91 \n",
       "Q 1331 -91 842 396 \n",
       "Q 353 884 353 1716 \n",
       "Q 353 2575 817 3079 \n",
       "Q 1281 3584 2069 3584 \n",
       "Q 2775 3584 3186 3129 \n",
       "Q 3597 2675 3597 1894 \n",
       "z\n",
       "M 3022 2063 \n",
       "Q 3016 2534 2758 2815 \n",
       "Q 2500 3097 2075 3097 \n",
       "Q 1594 3097 1305 2825 \n",
       "Q 1016 2553 972 2059 \n",
       "L 3022 2063 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       <path id=\"DejaVuSans-70\" d=\"M 1159 525 \n",
       "L 1159 -1331 \n",
       "L 581 -1331 \n",
       "L 581 3500 \n",
       "L 1159 3500 \n",
       "L 1159 2969 \n",
       "Q 1341 3281 1617 3432 \n",
       "Q 1894 3584 2278 3584 \n",
       "Q 2916 3584 3314 3078 \n",
       "Q 3713 2572 3713 1747 \n",
       "Q 3713 922 3314 415 \n",
       "Q 2916 -91 2278 -91 \n",
       "Q 1894 -91 1617 61 \n",
       "Q 1341 213 1159 525 \n",
       "z\n",
       "M 3116 1747 \n",
       "Q 3116 2381 2855 2742 \n",
       "Q 2594 3103 2138 3103 \n",
       "Q 1681 3103 1420 2742 \n",
       "Q 1159 2381 1159 1747 \n",
       "Q 1159 1113 1420 752 \n",
       "Q 1681 391 2138 391 \n",
       "Q 2594 391 2855 752 \n",
       "Q 3116 1113 3116 1747 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       <path id=\"DejaVuSans-6f\" d=\"M 1959 3097 \n",
       "Q 1497 3097 1228 2736 \n",
       "Q 959 2375 959 1747 \n",
       "Q 959 1119 1226 758 \n",
       "Q 1494 397 1959 397 \n",
       "Q 2419 397 2687 759 \n",
       "Q 2956 1122 2956 1747 \n",
       "Q 2956 2369 2687 2733 \n",
       "Q 2419 3097 1959 3097 \n",
       "z\n",
       "M 1959 3584 \n",
       "Q 2709 3584 3137 3096 \n",
       "Q 3566 2609 3566 1747 \n",
       "Q 3566 888 3137 398 \n",
       "Q 2709 -91 1959 -91 \n",
       "Q 1206 -91 779 398 \n",
       "Q 353 888 353 1747 \n",
       "Q 353 2609 779 3096 \n",
       "Q 1206 3584 1959 3584 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       <path id=\"DejaVuSans-63\" d=\"M 3122 3366 \n",
       "L 3122 2828 \n",
       "Q 2878 2963 2633 3030 \n",
       "Q 2388 3097 2138 3097 \n",
       "Q 1578 3097 1268 2742 \n",
       "Q 959 2388 959 1747 \n",
       "Q 959 1106 1268 751 \n",
       "Q 1578 397 2138 397 \n",
       "Q 2388 397 2633 464 \n",
       "Q 2878 531 3122 666 \n",
       "L 3122 134 \n",
       "Q 2881 22 2623 -34 \n",
       "Q 2366 -91 2075 -91 \n",
       "Q 1284 -91 818 406 \n",
       "Q 353 903 353 1747 \n",
       "Q 353 2603 823 3093 \n",
       "Q 1294 3584 2113 3584 \n",
       "Q 2378 3584 2631 3529 \n",
       "Q 2884 3475 3122 3366 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       <path id=\"DejaVuSans-68\" d=\"M 3513 2113 \n",
       "L 3513 0 \n",
       "L 2938 0 \n",
       "L 2938 2094 \n",
       "Q 2938 2591 2744 2837 \n",
       "Q 2550 3084 2163 3084 \n",
       "Q 1697 3084 1428 2787 \n",
       "Q 1159 2491 1159 1978 \n",
       "L 1159 0 \n",
       "L 581 0 \n",
       "L 581 4863 \n",
       "L 1159 4863 \n",
       "L 1159 2956 \n",
       "Q 1366 3272 1645 3428 \n",
       "Q 1925 3584 2291 3584 \n",
       "Q 2894 3584 3203 3211 \n",
       "Q 3513 2838 3513 2113 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "      </defs>\n",
       "      <use xlink:href=\"#DejaVuSans-65\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-70\" x=\"61.523438\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-6f\" x=\"125\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-63\" x=\"186.181641\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-68\" x=\"241.162109\"/>\n",
       "     </g>\n",
       "    </g>\n",
       "   </g>\n",
       "   <g id=\"matplotlib.axis_2\">\n",
       "    <g id=\"ytick_1\">\n",
       "     <g id=\"line2d_9\">\n",
       "      <path d=\"M 30.103125 119.608923 \n",
       "L 225.403125 119.608923 \n",
       "\" clip-path=\"url(#pd4b015918e)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
       "     </g>\n",
       "     <g id=\"line2d_10\">\n",
       "      <defs>\n",
       "       <path id=\"m516d22056b\" d=\"M 0 0 \n",
       "L -3.5 0 \n",
       "\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </defs>\n",
       "      <g>\n",
       "       <use xlink:href=\"#m516d22056b\" x=\"30.103125\" y=\"119.608923\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_6\">\n",
       "      <!-- 0.5 -->\n",
       "      <g transform=\"translate(7.2 123.408142) scale(0.1 -0.1)\">\n",
       "       <defs>\n",
       "        <path id=\"DejaVuSans-2e\" d=\"M 684 794 \n",
       "L 1344 794 \n",
       "L 1344 0 \n",
       "L 684 0 \n",
       "L 684 794 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       </defs>\n",
       "       <use xlink:href=\"#DejaVuSans-30\"/>\n",
       "       <use xlink:href=\"#DejaVuSans-2e\" x=\"63.623047\"/>\n",
       "       <use xlink:href=\"#DejaVuSans-35\" x=\"95.410156\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"ytick_2\">\n",
       "     <g id=\"line2d_11\">\n",
       "      <path d=\"M 30.103125 91.319391 \n",
       "L 225.403125 91.319391 \n",
       "\" clip-path=\"url(#pd4b015918e)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
       "     </g>\n",
       "     <g id=\"line2d_12\">\n",
       "      <g>\n",
       "       <use xlink:href=\"#m516d22056b\" x=\"30.103125\" y=\"91.319391\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_7\">\n",
       "      <!-- 1.0 -->\n",
       "      <g transform=\"translate(7.2 95.11861) scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-31\"/>\n",
       "       <use xlink:href=\"#DejaVuSans-2e\" x=\"63.623047\"/>\n",
       "       <use xlink:href=\"#DejaVuSans-30\" x=\"95.410156\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"ytick_3\">\n",
       "     <g id=\"line2d_13\">\n",
       "      <path d=\"M 30.103125 63.02986 \n",
       "L 225.403125 63.02986 \n",
       "\" clip-path=\"url(#pd4b015918e)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
       "     </g>\n",
       "     <g id=\"line2d_14\">\n",
       "      <g>\n",
       "       <use xlink:href=\"#m516d22056b\" x=\"30.103125\" y=\"63.02986\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_8\">\n",
       "      <!-- 1.5 -->\n",
       "      <g transform=\"translate(7.2 66.829079) scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-31\"/>\n",
       "       <use xlink:href=\"#DejaVuSans-2e\" x=\"63.623047\"/>\n",
       "       <use xlink:href=\"#DejaVuSans-35\" x=\"95.410156\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"ytick_4\">\n",
       "     <g id=\"line2d_15\">\n",
       "      <path d=\"M 30.103125 34.740328 \n",
       "L 225.403125 34.740328 \n",
       "\" clip-path=\"url(#pd4b015918e)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
       "     </g>\n",
       "     <g id=\"line2d_16\">\n",
       "      <g>\n",
       "       <use xlink:href=\"#m516d22056b\" x=\"30.103125\" y=\"34.740328\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_9\">\n",
       "      <!-- 2.0 -->\n",
       "      <g transform=\"translate(7.2 38.539547) scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-32\"/>\n",
       "       <use xlink:href=\"#DejaVuSans-2e\" x=\"63.623047\"/>\n",
       "       <use xlink:href=\"#DejaVuSans-30\" x=\"95.410156\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "   </g>\n",
       "   <g id=\"line2d_17\">\n",
       "    <path d=\"M 21.813651 13.5 \n",
       "L 23.803125 13.959956 \n",
       "L 25.792599 18.314943 \n",
       "L 27.782072 20.599222 \n",
       "L 29.771546 21.403164 \n",
       "L 30.103125 21.593929 \n",
       "L 32.092599 35.031886 \n",
       "L 34.082072 37.842296 \n",
       "L 36.071546 37.00477 \n",
       "L 38.06102 36.633177 \n",
       "L 40.050493 36.107295 \n",
       "L 40.382072 36.594667 \n",
       "L 42.371546 41.554298 \n",
       "L 44.36102 39.161163 \n",
       "L 46.350493 41.192364 \n",
       "L 48.339967 41.080482 \n",
       "L 50.329441 42.009786 \n",
       "L 50.66102 41.958767 \n",
       "L 52.650493 53.113499 \n",
       "L 54.639967 50.183821 \n",
       "L 56.629441 49.600038 \n",
       "L 58.618914 50.520772 \n",
       "L 60.608388 49.177246 \n",
       "L 60.939967 48.881344 \n",
       "L 62.929441 52.968849 \n",
       "L 64.918914 55.281783 \n",
       "L 66.908388 54.726028 \n",
       "L 68.897862 56.341362 \n",
       "L 70.887336 56.58637 \n",
       "L 71.218914 56.345292 \n",
       "L 73.208388 61.609157 \n",
       "L 75.197862 59.359753 \n",
       "L 77.187336 62.856496 \n",
       "L 79.176809 61.929198 \n",
       "L 81.166283 63.690547 \n",
       "L 81.497862 63.288965 \n",
       "L 83.487336 69.099013 \n",
       "L 85.476809 63.943153 \n",
       "L 87.466283 61.18615 \n",
       "L 89.455757 60.34854 \n",
       "L 91.44523 61.821733 \n",
       "L 91.776809 62.11326 \n",
       "L 93.766283 77.243736 \n",
       "L 95.755757 70.31459 \n",
       "L 97.74523 68.37743 \n",
       "L 99.734704 69.158547 \n",
       "L 101.724178 67.814736 \n",
       "L 102.055757 67.775212 \n",
       "L 104.04523 75.947353 \n",
       "L 106.034704 71.993159 \n",
       "L 108.024178 68.901052 \n",
       "L 110.013651 69.039159 \n",
       "L 112.003125 69.827443 \n",
       "L 112.334704 69.959979 \n",
       "L 114.324178 74.311477 \n",
       "L 116.313651 77.892257 \n",
       "L 118.303125 75.422776 \n",
       "L 120.292599 75.293338 \n",
       "L 122.282072 76.029783 \n",
       "L 122.613651 75.836435 \n",
       "L 124.603125 82.386094 \n",
       "L 126.592599 78.267309 \n",
       "L 128.582072 80.658524 \n",
       "L 130.571546 80.29485 \n",
       "L 132.56102 80.430826 \n",
       "L 132.892599 80.754975 \n",
       "L 134.882072 90.327491 \n",
       "L 136.871546 88.041926 \n",
       "L 138.86102 88.022719 \n",
       "L 140.850493 86.255366 \n",
       "L 142.839967 85.398826 \n",
       "L 143.171546 85.043938 \n",
       "L 145.16102 94.995442 \n",
       "L 147.150493 87.957732 \n",
       "L 149.139967 90.024121 \n",
       "L 151.129441 89.571122 \n",
       "L 153.118914 88.161717 \n",
       "L 153.450493 87.718549 \n",
       "L 155.439967 93.474846 \n",
       "L 157.429441 92.369951 \n",
       "L 159.418914 92.248115 \n",
       "L 161.408388 91.708973 \n",
       "L 163.397862 90.647318 \n",
       "L 163.729441 90.410337 \n",
       "L 165.718914 91.844627 \n",
       "L 167.708388 93.817854 \n",
       "L 169.697862 91.792146 \n",
       "L 171.687336 93.185085 \n",
       "L 173.676809 93.779978 \n",
       "L 174.008388 93.473303 \n",
       "L 175.997862 97.453304 \n",
       "L 177.987336 97.578782 \n",
       "L 179.976809 97.264014 \n",
       "L 181.966283 97.415558 \n",
       "L 183.955757 97.514515 \n",
       "L 184.287336 97.495298 \n",
       "L 186.276809 105.289763 \n",
       "L 188.266283 103.708011 \n",
       "L 190.255757 103.578097 \n",
       "L 192.24523 102.849596 \n",
       "L 194.234704 102.482086 \n",
       "L 194.566283 102.871349 \n",
       "L 196.555757 109.233928 \n",
       "L 198.54523 106.629036 \n",
       "L 200.534704 104.750203 \n",
       "L 202.524178 106.763143 \n",
       "L 204.513651 107.023954 \n",
       "L 204.84523 106.773407 \n",
       "L 206.834704 113.577724 \n",
       "L 208.824178 110.549242 \n",
       "L 210.813651 110.636892 \n",
       "L 212.803125 111.417192 \n",
       "L 214.792599 111.003687 \n",
       "L 215.124178 110.750876 \n",
       "L 217.113651 115.630679 \n",
       "L 219.103125 115.538346 \n",
       "L 221.092599 115.43328 \n",
       "L 223.082072 114.299445 \n",
       "L 225.071546 113.351777 \n",
       "L 225.403125 113.507999 \n",
       "\" clip-path=\"url(#pd4b015918e)\" style=\"fill: none; stroke: #1f77b4; stroke-width: 1.5; stroke-linecap: square\"/>\n",
       "   </g>\n",
       "   <g id=\"line2d_18\">\n",
       "    <path d=\"M 21.813651 139.352659 \n",
       "L 23.803125 139.5 \n",
       "L 25.792599 136.798743 \n",
       "L 27.782072 136.405832 \n",
       "L 29.771546 136.405832 \n",
       "L 30.103125 136.206209 \n",
       "L 32.092599 130.51218 \n",
       "L 34.082072 130.217497 \n",
       "L 36.071546 130.51218 \n",
       "L 38.06102 130.291168 \n",
       "L 40.050493 130.924736 \n",
       "L 40.382072 130.730816 \n",
       "L 42.371546 129.922815 \n",
       "L 44.36102 131.838252 \n",
       "L 46.350493 131.101545 \n",
       "L 48.339967 130.585851 \n",
       "L 50.329441 130.983672 \n",
       "L 50.66102 130.844886 \n",
       "L 52.650493 125.797258 \n",
       "L 54.639967 127.860036 \n",
       "L 56.629441 127.958264 \n",
       "L 58.618914 127.491683 \n",
       "L 60.608388 127.093862 \n",
       "L 60.939967 127.308695 \n",
       "L 62.929441 126.091941 \n",
       "L 64.918914 126.386623 \n",
       "L 66.908388 127.467126 \n",
       "L 68.897862 126.460294 \n",
       "L 70.887336 126.26875 \n",
       "L 71.218914 126.225023 \n",
       "L 73.208388 123.439797 \n",
       "L 75.197862 123.145115 \n",
       "L 77.187336 122.457522 \n",
       "L 79.176809 122.62942 \n",
       "L 81.166283 121.730638 \n",
       "L 81.497862 121.833302 \n",
       "L 83.487336 121.377019 \n",
       "L 85.476809 122.408408 \n",
       "L 87.466283 123.243342 \n",
       "L 89.455757 123.513468 \n",
       "L 91.44523 123.145115 \n",
       "L 91.776809 123.145115 \n",
       "L 93.766283 116.956779 \n",
       "L 95.755757 119.019558 \n",
       "L 97.74523 119.805378 \n",
       "L 99.734704 120.124618 \n",
       "L 101.724178 120.964463 \n",
       "L 102.055757 120.806665 \n",
       "L 104.04523 115.483366 \n",
       "L 106.034704 118.872217 \n",
       "L 108.024178 119.412468 \n",
       "L 110.013651 119.461582 \n",
       "L 112.003125 119.196367 \n",
       "L 112.334704 119.095605 \n",
       "L 114.324178 118.724875 \n",
       "L 116.313651 117.398803 \n",
       "L 118.303125 118.331965 \n",
       "L 120.292599 118.209181 \n",
       "L 122.282072 118.076573 \n",
       "L 122.613651 118.240075 \n",
       "L 124.603125 117.840827 \n",
       "L 126.592599 117.398803 \n",
       "L 128.582072 116.170959 \n",
       "L 130.571546 115.999061 \n",
       "L 132.56102 116.013795 \n",
       "L 132.892599 115.787555 \n",
       "L 134.882072 111.35781 \n",
       "L 136.871546 112.389199 \n",
       "L 138.86102 112.53654 \n",
       "L 140.850493 113.788941 \n",
       "L 142.839967 114.009953 \n",
       "L 143.171546 113.962424 \n",
       "L 145.16102 110.768444 \n",
       "L 147.150493 113.715271 \n",
       "L 149.139967 112.92945 \n",
       "L 151.129441 112.46287 \n",
       "L 153.118914 113.066969 \n",
       "L 153.450493 113.220964 \n",
       "L 155.439967 111.947175 \n",
       "L 157.429441 112.978564 \n",
       "L 159.418914 112.045402 \n",
       "L 161.408388 112.094516 \n",
       "L 163.397862 112.654413 \n",
       "L 163.729441 112.707646 \n",
       "L 165.718914 111.947175 \n",
       "L 167.708388 111.35781 \n",
       "L 169.697862 111.35781 \n",
       "L 171.687336 110.768444 \n",
       "L 173.676809 110.179079 \n",
       "L 174.008388 110.426232 \n",
       "L 175.997862 107.821618 \n",
       "L 177.987336 108.116301 \n",
       "L 179.976809 108.214528 \n",
       "L 181.966283 108.926678 \n",
       "L 183.955757 109.118222 \n",
       "L 184.287336 109.22849 \n",
       "L 186.276809 105.75884 \n",
       "L 188.266283 107.084912 \n",
       "L 190.255757 106.93757 \n",
       "L 192.24523 107.600606 \n",
       "L 194.234704 107.585872 \n",
       "L 194.566283 107.51743 \n",
       "L 196.555757 103.990744 \n",
       "L 198.54523 104.285427 \n",
       "L 200.534704 105.36593 \n",
       "L 202.524178 104.727451 \n",
       "L 204.513651 104.933729 \n",
       "L 204.84523 105.007874 \n",
       "L 206.834704 103.990744 \n",
       "L 208.824178 104.285427 \n",
       "L 210.813651 103.794289 \n",
       "L 212.803125 103.843403 \n",
       "L 214.792599 104.285427 \n",
       "L 215.124178 104.494556 \n",
       "L 217.113651 100.454553 \n",
       "L 219.103125 102.075307 \n",
       "L 221.092599 102.419104 \n",
       "L 223.082072 103.033026 \n",
       "L 225.071546 103.401379 \n",
       "L 225.403125 103.410885 \n",
       "\" clip-path=\"url(#pd4b015918e)\" style=\"fill: none; stroke-dasharray: 5.55,2.4; stroke-dashoffset: 0; stroke: #bf00bf; stroke-width: 1.5\"/>\n",
       "   </g>\n",
       "   <g id=\"line2d_19\"/>\n",
       "   <g id=\"patch_3\">\n",
       "    <path d=\"M 30.103125 145.8 \n",
       "L 30.103125 7.2 \n",
       "\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
       "   </g>\n",
       "   <g id=\"patch_4\">\n",
       "    <path d=\"M 225.403125 145.8 \n",
       "L 225.403125 7.2 \n",
       "\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
       "   </g>\n",
       "   <g id=\"patch_5\">\n",
       "    <path d=\"M 30.103125 145.8 \n",
       "L 225.403125 145.8 \n",
       "\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
       "   </g>\n",
       "   <g id=\"patch_6\">\n",
       "    <path d=\"M 30.103125 7.2 \n",
       "L 225.403125 7.2 \n",
       "\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
       "   </g>\n",
       "   <g id=\"legend_1\">\n",
       "    <g id=\"patch_7\">\n",
       "     <path d=\"M 140.634375 44.55625 \n",
       "L 218.403125 44.55625 \n",
       "Q 220.403125 44.55625 220.403125 42.55625 \n",
       "L 220.403125 14.2 \n",
       "Q 220.403125 12.2 218.403125 12.2 \n",
       "L 140.634375 12.2 \n",
       "Q 138.634375 12.2 138.634375 14.2 \n",
       "L 138.634375 42.55625 \n",
       "Q 138.634375 44.55625 140.634375 44.55625 \n",
       "z\n",
       "\" style=\"fill: #ffffff; opacity: 0.8; stroke: #cccccc; stroke-linejoin: miter\"/>\n",
       "    </g>\n",
       "    <g id=\"line2d_20\">\n",
       "     <path d=\"M 142.634375 20.298438 \n",
       "L 152.634375 20.298438 \n",
       "L 162.634375 20.298438 \n",
       "\" style=\"fill: none; stroke: #1f77b4; stroke-width: 1.5; stroke-linecap: square\"/>\n",
       "    </g>\n",
       "    <g id=\"text_10\">\n",
       "     <!-- train loss -->\n",
       "     <g transform=\"translate(170.634375 23.798438) scale(0.1 -0.1)\">\n",
       "      <defs>\n",
       "       <path id=\"DejaVuSans-74\" d=\"M 1172 4494 \n",
       "L 1172 3500 \n",
       "L 2356 3500 \n",
       "L 2356 3053 \n",
       "L 1172 3053 \n",
       "L 1172 1153 \n",
       "Q 1172 725 1289 603 \n",
       "Q 1406 481 1766 481 \n",
       "L 2356 481 \n",
       "L 2356 0 \n",
       "L 1766 0 \n",
       "Q 1100 0 847 248 \n",
       "Q 594 497 594 1153 \n",
       "L 594 3053 \n",
       "L 172 3053 \n",
       "L 172 3500 \n",
       "L 594 3500 \n",
       "L 594 4494 \n",
       "L 1172 4494 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       <path id=\"DejaVuSans-72\" d=\"M 2631 2963 \n",
       "Q 2534 3019 2420 3045 \n",
       "Q 2306 3072 2169 3072 \n",
       "Q 1681 3072 1420 2755 \n",
       "Q 1159 2438 1159 1844 \n",
       "L 1159 0 \n",
       "L 581 0 \n",
       "L 581 3500 \n",
       "L 1159 3500 \n",
       "L 1159 2956 \n",
       "Q 1341 3275 1631 3429 \n",
       "Q 1922 3584 2338 3584 \n",
       "Q 2397 3584 2469 3576 \n",
       "Q 2541 3569 2628 3553 \n",
       "L 2631 2963 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       <path id=\"DejaVuSans-61\" d=\"M 2194 1759 \n",
       "Q 1497 1759 1228 1600 \n",
       "Q 959 1441 959 1056 \n",
       "Q 959 750 1161 570 \n",
       "Q 1363 391 1709 391 \n",
       "Q 2188 391 2477 730 \n",
       "Q 2766 1069 2766 1631 \n",
       "L 2766 1759 \n",
       "L 2194 1759 \n",
       "z\n",
       "M 3341 1997 \n",
       "L 3341 0 \n",
       "L 2766 0 \n",
       "L 2766 531 \n",
       "Q 2569 213 2275 61 \n",
       "Q 1981 -91 1556 -91 \n",
       "Q 1019 -91 701 211 \n",
       "Q 384 513 384 1019 \n",
       "Q 384 1609 779 1909 \n",
       "Q 1175 2209 1959 2209 \n",
       "L 2766 2209 \n",
       "L 2766 2266 \n",
       "Q 2766 2663 2505 2880 \n",
       "Q 2244 3097 1772 3097 \n",
       "Q 1472 3097 1187 3025 \n",
       "Q 903 2953 641 2809 \n",
       "L 641 3341 \n",
       "Q 956 3463 1253 3523 \n",
       "Q 1550 3584 1831 3584 \n",
       "Q 2591 3584 2966 3190 \n",
       "Q 3341 2797 3341 1997 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       <path id=\"DejaVuSans-69\" d=\"M 603 3500 \n",
       "L 1178 3500 \n",
       "L 1178 0 \n",
       "L 603 0 \n",
       "L 603 3500 \n",
       "z\n",
       "M 603 4863 \n",
       "L 1178 4863 \n",
       "L 1178 4134 \n",
       "L 603 4134 \n",
       "L 603 4863 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       <path id=\"DejaVuSans-6e\" d=\"M 3513 2113 \n",
       "L 3513 0 \n",
       "L 2938 0 \n",
       "L 2938 2094 \n",
       "Q 2938 2591 2744 2837 \n",
       "Q 2550 3084 2163 3084 \n",
       "Q 1697 3084 1428 2787 \n",
       "Q 1159 2491 1159 1978 \n",
       "L 1159 0 \n",
       "L 581 0 \n",
       "L 581 3500 \n",
       "L 1159 3500 \n",
       "L 1159 2956 \n",
       "Q 1366 3272 1645 3428 \n",
       "Q 1925 3584 2291 3584 \n",
       "Q 2894 3584 3203 3211 \n",
       "Q 3513 2838 3513 2113 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       <path id=\"DejaVuSans-20\" transform=\"scale(0.015625)\"/>\n",
       "       <path id=\"DejaVuSans-6c\" d=\"M 603 4863 \n",
       "L 1178 4863 \n",
       "L 1178 0 \n",
       "L 603 0 \n",
       "L 603 4863 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       <path id=\"DejaVuSans-73\" d=\"M 2834 3397 \n",
       "L 2834 2853 \n",
       "Q 2591 2978 2328 3040 \n",
       "Q 2066 3103 1784 3103 \n",
       "Q 1356 3103 1142 2972 \n",
       "Q 928 2841 928 2578 \n",
       "Q 928 2378 1081 2264 \n",
       "Q 1234 2150 1697 2047 \n",
       "L 1894 2003 \n",
       "Q 2506 1872 2764 1633 \n",
       "Q 3022 1394 3022 966 \n",
       "Q 3022 478 2636 193 \n",
       "Q 2250 -91 1575 -91 \n",
       "Q 1294 -91 989 -36 \n",
       "Q 684 19 347 128 \n",
       "L 347 722 \n",
       "Q 666 556 975 473 \n",
       "Q 1284 391 1588 391 \n",
       "Q 1994 391 2212 530 \n",
       "Q 2431 669 2431 922 \n",
       "Q 2431 1156 2273 1281 \n",
       "Q 2116 1406 1581 1522 \n",
       "L 1381 1569 \n",
       "Q 847 1681 609 1914 \n",
       "Q 372 2147 372 2553 \n",
       "Q 372 3047 722 3315 \n",
       "Q 1072 3584 1716 3584 \n",
       "Q 2034 3584 2315 3537 \n",
       "Q 2597 3491 2834 3397 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "      </defs>\n",
       "      <use xlink:href=\"#DejaVuSans-74\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-72\" x=\"39.208984\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-61\" x=\"80.322266\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-69\" x=\"141.601562\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-6e\" x=\"169.384766\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-20\" x=\"232.763672\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-6c\" x=\"264.550781\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-6f\" x=\"292.333984\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-73\" x=\"353.515625\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-73\" x=\"405.615234\"/>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"line2d_21\">\n",
       "     <path d=\"M 142.634375 34.976562 \n",
       "L 152.634375 34.976562 \n",
       "L 162.634375 34.976562 \n",
       "\" style=\"fill: none; stroke-dasharray: 5.55,2.4; stroke-dashoffset: 0; stroke: #bf00bf; stroke-width: 1.5\"/>\n",
       "    </g>\n",
       "    <g id=\"text_11\">\n",
       "     <!-- train acc -->\n",
       "     <g transform=\"translate(170.634375 38.476562) scale(0.1 -0.1)\">\n",
       "      <use xlink:href=\"#DejaVuSans-74\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-72\" x=\"39.208984\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-61\" x=\"80.322266\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-69\" x=\"141.601562\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-6e\" x=\"169.384766\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-20\" x=\"232.763672\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-61\" x=\"264.550781\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-63\" x=\"325.830078\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-63\" x=\"380.810547\"/>\n",
       "     </g>\n",
       "    </g>\n",
       "   </g>\n",
       "  </g>\n",
       " </g>\n",
       " <defs>\n",
       "  <clipPath id=\"pd4b015918e\">\n",
       "   <rect x=\"30.103125\" y=\"7.2\" width=\"195.3\" height=\"138.6\"/>\n",
       "  </clipPath>\n",
       " </defs>\n",
       "</svg>\n"
      ],
      "text/plain": [
       "<Figure size 350x250 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "net, preds = get_net(), []\n",
    "net(next(iter(train_valid_iter))[0])\n",
    "train(net, train_valid_iter, None, num_epochs, lr, wd, devices, lr_period,\n",
    "      lr_decay)\n",
    "\n",
    "for X, _ in test_iter:\n",
    "    y_hat = net(X.to(devices[0]))\n",
    "    preds.extend(y_hat.argmax(dim=1).type(torch.int32).cpu().numpy())\n",
    "sorted_ids = list(range(1, len(test_ds) + 1))\n",
    "sorted_ids.sort(key=lambda x: str(x))\n",
    "df = pd.DataFrame({'id': sorted_ids, 'label': preds})\n",
    "df['label'] = df['label'].apply(lambda x: train_valid_ds.classes[x])\n",
    "df.to_csv('submission.csv', index=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0a33f6c7",
   "metadata": {
    "origin_pos": 45
   },
   "source": [
    "The above code\n",
    "will generate a `submission.csv` file,\n",
    "whose format\n",
    "meets the requirement of the Kaggle competition.\n",
    "The method\n",
    "for submitting results to Kaggle\n",
    "is similar to that in :numref:`sec_kaggle_house`.\n",
    "\n",
    "## Summary\n",
    "\n",
    "* We can read datasets containing raw image files after organizing them into the required format.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fe9ddf70",
   "metadata": {
    "origin_pos": 47,
    "tab": [
     "pytorch"
    ]
   },
   "source": [
    "* We can use convolutional neural networks and image augmentation in an image classification competition.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ed9f69e2",
   "metadata": {
    "origin_pos": 48
   },
   "source": [
    "## Exercises\n",
    "\n",
    "1. Use the complete CIFAR-10 dataset for this Kaggle competition. Set hyperparameters as `batch_size = 128`, `num_epochs = 100`, `lr = 0.1`, `lr_period = 50`, and `lr_decay = 0.1`.  See what accuracy and ranking you can achieve in this competition. Can you further improve them?\n",
    "1. What accuracy can you get when not using image augmentation?\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "74fe0ee3",
   "metadata": {
    "origin_pos": 50,
    "tab": [
     "pytorch"
    ]
   },
   "source": [
    "[Discussions](https://discuss.d2l.ai/t/1479)\n"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  },
  "required_libs": []
 },
 "nbformat": 4,
 "nbformat_minor": 5
}